mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-29 09:32:21 +00:00
Merge branch 'main' into add-watsonx-inference-adapter
This commit is contained in:
commit
28e6c8478b
308 changed files with 33749 additions and 5102 deletions
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ import re
|
|||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
|
@ -153,7 +153,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=response.call_id,
|
||||
tool_name=response.tool_name,
|
||||
content=response.content,
|
||||
)
|
||||
)
|
||||
|
|
@ -181,7 +180,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
|
|
@ -191,7 +191,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
with tracing.span("resume_turn") as span:
|
||||
await self._initialize_tools()
|
||||
async with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
|
|
@ -218,18 +219,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
steps = []
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
if is_resume:
|
||||
if isinstance(request.tool_responses[0], ToolResponseMessage):
|
||||
tool_response_messages = request.tool_responses
|
||||
tool_responses = [
|
||||
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
else:
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
tool_responses = request.tool_responses
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
|
||||
]
|
||||
messages.extend(tool_response_messages)
|
||||
last_turn = turns[-1]
|
||||
last_turn_messages = self.turn_to_messages(last_turn)
|
||||
|
|
@ -247,12 +239,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||
request.session_id, request.turn_id
|
||||
)
|
||||
now = datetime.now().astimezone().isoformat()
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=tool_responses,
|
||||
tool_responses=request.tool_responses,
|
||||
completed_at=now,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||
)
|
||||
|
|
@ -272,7 +264,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
start_time = last_turn.started_at
|
||||
else:
|
||||
messages.extend(request.messages)
|
||||
start_time = datetime.now().astimezone().isoformat()
|
||||
start_time = datetime.now(timezone.utc).isoformat()
|
||||
input_messages = request.messages
|
||||
|
||||
output_message = None
|
||||
|
|
@ -283,7 +275,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents if not is_resume else None,
|
||||
toolgroups_for_turn=request.toolgroups if not is_resume else None,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
|
|
@ -304,7 +295,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages=input_messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
|
@ -335,7 +326,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
|
|
@ -358,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params,
|
||||
stream,
|
||||
documents,
|
||||
toolgroups_for_turn,
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -390,14 +379,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
with tracing.span("run_shields") as span:
|
||||
async with tracing.span("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
shield_call_start_time = datetime.now().astimezone().isoformat()
|
||||
shield_call_start_time = datetime.now(timezone.utc).isoformat()
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
|
@ -421,7 +410,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -444,7 +433,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
violation=None,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -459,30 +448,19 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# TODO: simplify all of this code, it can be simpler
|
||||
toolgroup_args = {}
|
||||
toolgroups = set()
|
||||
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroups.add(tool_group_name)
|
||||
toolgroup_args[tool_group_name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
|
||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||
if documents:
|
||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||
await self.handle_documents(session_id, documents, input_messages)
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info and session_info.vector_db_id:
|
||||
if RAG_TOOL_GROUP not in toolgroup_args:
|
||||
toolgroup_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
|
||||
else:
|
||||
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
for tool_name in self.tool_name_to_args.keys():
|
||||
if tool_name == MEMORY_QUERY_TOOL:
|
||||
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
|
||||
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
|
||||
else:
|
||||
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
|
||||
output_attachments = []
|
||||
|
||||
|
|
@ -494,7 +472,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
client_tools[tool.name] = tool
|
||||
while True:
|
||||
step_id = str(uuid.uuid4())
|
||||
inference_start_time = datetime.now().astimezone().isoformat()
|
||||
inference_start_time = datetime.now(timezone.utc).isoformat()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
|
|
@ -508,11 +486,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
with tracing.span("inference") as span:
|
||||
async with tracing.span("inference") as span:
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=tool_defs,
|
||||
tools=self.tool_defs,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||
response_format=self.agent_config.response_format,
|
||||
stream=True,
|
||||
|
|
@ -604,7 +582,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
model_response=copy.deepcopy(message),
|
||||
started_at=inference_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -636,125 +614,143 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
|
||||
# 1. Start the tool execution step and progress
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
tool_call = message.tool_calls[0]
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
tool_call=tool_call,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
tool_call=tool_call,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
input_messages = input_messages + [message]
|
||||
|
||||
# If tool is a client tool, yield CompletionMessage and return
|
||||
if tool_call.tool_name in client_tools:
|
||||
# NOTE: mark end_of_message to indicate to client that it may
|
||||
# call the tool and continue the conversation with the tool's response.
|
||||
message.stop_reason = StopReason.end_of_message
|
||||
# Process tool calls in the message
|
||||
client_tool_calls = []
|
||||
non_client_tool_calls = []
|
||||
|
||||
# Separate client and non-client tool calls
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.tool_name in client_tools:
|
||||
client_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_client_tool_calls.append(tool_call)
|
||||
|
||||
# Process non-client tool calls first
|
||||
for tool_call in non_client_tool_calls:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
tool_call=tool_call,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Execute the tool call
|
||||
async with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
session_id,
|
||||
tool_call,
|
||||
)
|
||||
if tool_result.content is None:
|
||||
raise ValueError(
|
||||
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
|
||||
)
|
||||
result_message = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
content=tool_result.content,
|
||||
)
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
# Store tool execution step
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=tool_result.content,
|
||||
metadata=tool_result.metadata,
|
||||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
# Yield the step completion event
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Add the result message to input_messages for the next iteration
|
||||
input_messages.append(result_message)
|
||||
|
||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||
# but that needs a lot more refactoring of Tool code potentially
|
||||
if (type(result_message.content) is str) and (
|
||||
out_attachment := _interpret_content_as_attachment(result_message.content)
|
||||
):
|
||||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
# attached file path etc. since the model is trained to only provide a user message
|
||||
# with the summary. We keep all generated attachments and then attach them to final message
|
||||
output_attachments.append(out_attachment)
|
||||
|
||||
# If there are client tool calls, yield a message with only those tool calls
|
||||
if client_tool_calls:
|
||||
await self.storage.set_in_progress_tool_call_step(
|
||||
session_id,
|
||||
turn_id,
|
||||
ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_calls=client_tool_calls,
|
||||
tool_responses=[],
|
||||
started_at=datetime.now().astimezone().isoformat(),
|
||||
started_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
yield message
|
||||
|
||||
# Create a copy of the message with only client tool calls
|
||||
client_message = message.model_copy(deep=True)
|
||||
client_message.tool_calls = client_tool_calls
|
||||
# NOTE: mark end_of_message to indicate to client that it may
|
||||
# call the tool and continue the conversation with the tool's response.
|
||||
client_message.stop_reason = StopReason.end_of_message
|
||||
|
||||
# Yield the message with client tool calls
|
||||
yield client_message
|
||||
return
|
||||
|
||||
# If tool is a builtin server tool, execute it
|
||||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now().astimezone().isoformat()
|
||||
tool_call = message.tool_calls[0]
|
||||
tool_result = await execute_tool_call_maybe(
|
||||
self.tool_runtime_api,
|
||||
session_id,
|
||||
tool_call,
|
||||
toolgroup_args,
|
||||
tool_to_group,
|
||||
)
|
||||
if tool_result.content is None:
|
||||
raise ValueError(
|
||||
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
|
||||
)
|
||||
result_messages = [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=tool_result.content,
|
||||
)
|
||||
]
|
||||
assert len(result_messages) == 1, "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
async def _initialize_tools(
|
||||
self,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> None:
|
||||
toolgroup_to_args = {}
|
||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroup_to_args[tool_group_name] = toolgroup.args
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=result_message.call_id,
|
||||
tool_name=result_message.tool_name,
|
||||
content=result_message.content,
|
||||
metadata=tool_result.metadata,
|
||||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||
# but that needs a lot more refactoring of Tool code potentially
|
||||
if (type(result_message.content) is str) and (
|
||||
out_attachment := _interpret_content_as_attachment(result_message.content)
|
||||
):
|
||||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
# attached file path etc. since the model is trained to only provide a user message
|
||||
# with the summary. We keep all generated attachments and then attach them to final message
|
||||
output_attachments.append(out_attachment)
|
||||
|
||||
input_messages = input_messages + [message, result_message]
|
||||
|
||||
async def _get_tool_defs(
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||
# Determine which tools to include
|
||||
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
|
||||
agent_config_toolgroups = []
|
||||
|
|
@ -763,8 +759,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if name not in agent_config_toolgroups:
|
||||
agent_config_toolgroups.append(name)
|
||||
|
||||
toolgroup_to_args = toolgroup_to_args or {}
|
||||
|
||||
tool_name_to_def = {}
|
||||
tool_to_group = {}
|
||||
tool_name_to_args = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
|
|
@ -782,53 +780,38 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||
if not tools.data:
|
||||
available_tool_groups = ", ".join(
|
||||
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
||||
)
|
||||
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
||||
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
|
||||
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
|
||||
raise ValueError(
|
||||
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
)
|
||||
|
||||
for tool_def in tools.data:
|
||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||
tool_name = tool_def.identifier
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
if tool_name == "web_search":
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
identifier: str | BuiltinTool | None = tool_def.identifier
|
||||
if identifier == "web_search":
|
||||
identifier = BuiltinTool.brave_search
|
||||
else:
|
||||
built_in_type = BuiltinTool(tool_name)
|
||||
identifier = BuiltinTool(identifier)
|
||||
else:
|
||||
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
|
||||
if input_tool_name in (None, tool_def.identifier):
|
||||
identifier = tool_def.identifier
|
||||
else:
|
||||
identifier = None
|
||||
|
||||
if tool_name_to_def.get(built_in_type, None):
|
||||
raise ValueError(f"Tool {built_in_type} already exists")
|
||||
|
||||
tool_name_to_def[built_in_type] = ToolDefinition(
|
||||
tool_name=built_in_type,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||
continue
|
||||
|
||||
if tool_name_to_def.get(tool_def.identifier, None):
|
||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||
if tool_name in (None, tool_def.identifier):
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
tool_name=identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
|
|
@ -840,9 +823,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
return list(tool_name_to_def.values()), tool_to_group
|
||||
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
|
||||
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||
"""Parse a toolgroup name into its components.
|
||||
|
|
@ -861,15 +844,46 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_group, tool_name = split_names[0], None
|
||||
return tool_group, tool_name
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
self,
|
||||
session_id: str,
|
||||
tool_call: ToolCall,
|
||||
) -> ToolInvocationResult:
|
||||
tool_name = tool_call.tool_name
|
||||
registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
|
||||
if tool_name not in registered_tool_names:
|
||||
raise ValueError(
|
||||
f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
|
||||
)
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
if tool_name == BuiltinTool.brave_search:
|
||||
tool_name_str = WEB_SEARCH_TOOL
|
||||
else:
|
||||
tool_name_str = tool_name.value
|
||||
else:
|
||||
tool_name_str = tool_name
|
||||
|
||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
)
|
||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
session_id: str,
|
||||
documents: List[Document],
|
||||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
|
|
@ -892,16 +906,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if memory_tool and code_interpreter_tool:
|
||||
# if both memory and code_interpreter are available, we download the URLs
|
||||
# and attach the data to the last message.
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
await attachment_message(self.tempdir, url_items, input_messages[-1])
|
||||
# Since memory is present, add all the data to the memory bank
|
||||
await self.add_to_session_vector_db(session_id, documents)
|
||||
elif code_interpreter_tool:
|
||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# and attach the path to them as a message to inference with the
|
||||
# assumption that the model invokes the code_interpreter tool with the path
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
await attachment_message(self.tempdir, url_items, input_messages[-1])
|
||||
elif memory_tool:
|
||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
await self.add_to_session_vector_db(session_id, documents)
|
||||
|
|
@ -968,8 +980,8 @@ async def load_data_from_urls(urls: List[URL]) -> List[str]:
|
|||
return data
|
||||
|
||||
|
||||
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
|
||||
content = []
|
||||
async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
|
||||
contents = []
|
||||
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
|
|
@ -989,48 +1001,19 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
else:
|
||||
raise ValueError(f"Unsupported URL {url}")
|
||||
|
||||
content.append(
|
||||
contents.append(
|
||||
TextContentItem(
|
||||
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
|
||||
)
|
||||
)
|
||||
|
||||
return ToolResponseMessage(
|
||||
call_id="",
|
||||
tool_name=BuiltinTool.code_interpreter,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
tool_runtime_api: ToolRuntime,
|
||||
session_id: str,
|
||||
tool_call: ToolCall,
|
||||
toolgroup_args: Dict[str, Dict[str, Any]],
|
||||
tool_to_group: Dict[str, str],
|
||||
) -> ToolInvocationResult:
|
||||
name = tool_call.tool_name
|
||||
group_name = tool_to_group.get(name, None)
|
||||
if group_name is None:
|
||||
raise ValueError(f"Tool {name} not found in any tool group")
|
||||
if isinstance(name, BuiltinTool):
|
||||
if name == BuiltinTool.brave_search:
|
||||
name = WEB_SEARCH_TOOL
|
||||
if isinstance(message.content, list):
|
||||
message.content.extend(contents)
|
||||
else:
|
||||
if isinstance(message.content, str):
|
||||
message.content = [TextContentItem(text=message.content)] + contents
|
||||
else:
|
||||
name = name.value
|
||||
|
||||
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logger.info(f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
message.content = [message.content] + contents
|
||||
|
||||
|
||||
def _interpret_content_as_attachment(
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import uuid
|
|||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
|
|
@ -21,6 +22,8 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListAgentSessionsResponse,
|
||||
ListAgentsResponse,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
|
@ -84,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> ChatAgent:
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_config = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
|
|
@ -120,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
|
|
@ -160,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
|
|
@ -169,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
|
||||
tool_responses: List[ToolResponse],
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
|
|
@ -188,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
return turn
|
||||
|
||||
|
|
@ -210,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
|
@ -232,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
pass
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
pass
|
||||
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -36,7 +36,7 @@ class AgentPersistence:
|
|||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(),
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import List
|
|||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
shield_id=identifier,
|
||||
messages=messages,
|
||||
)
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
|
||||
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
|
||||
for identifier, response in zip(identifiers, responses, strict=False):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LocalFSDatasetIOConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .datasetio import LocalFSDatasetIOImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class LocalFSDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "localfs_datasetio.db").as_posix()
|
||||
) # Uses SQLite config specific to localfs storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="localfs_datasetio.db",
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,20 +3,14 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import base64
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pandas
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
|
@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
|
|||
DATASETS_PREFIX = "localfs_datasets:"
|
||||
|
||||
|
||||
class BaseDataset(ABC):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, idx):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def load(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetInfo:
|
||||
dataset_def: Dataset
|
||||
dataset_impl: BaseDataset
|
||||
|
||||
|
||||
class PandasDataframeDataset(BaseDataset):
|
||||
class PandasDataframeDataset:
|
||||
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_def = dataset_def
|
||||
|
|
@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
|
|||
else:
|
||||
return self.df.iloc[idx].to_dict()
|
||||
|
||||
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
|
||||
# note that we will drop any columns in dataset that are not in the schema
|
||||
df = df[self.dataset_def.dataset_schema.keys()]
|
||||
# check all columns in dataset schema are present
|
||||
assert len(df.columns) == len(self.dataset_def.dataset_schema)
|
||||
# TODO: type checking against column types in dataset schema
|
||||
return df
|
||||
|
||||
def load(self) -> None:
|
||||
async def load(self) -> None:
|
||||
if self.df is not None:
|
||||
return
|
||||
|
||||
df = get_dataframe_from_url(self.dataset_def.url)
|
||||
if df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
if self.dataset_def.source.type == "uri":
|
||||
self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
|
||||
elif self.dataset_def.source.type == "rows":
|
||||
self.df = pandas.DataFrame(self.dataset_def.source.rows)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
|
||||
|
||||
self.df = self._validate_dataset_schema(df)
|
||||
if self.df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
|
||||
|
||||
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
|
|
@ -99,95 +66,55 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
|
||||
for dataset in stored_datasets:
|
||||
dataset = Dataset.model_validate_json(dataset)
|
||||
dataset_impl = PandasDataframeDataset(dataset)
|
||||
self.dataset_infos[dataset.identifier] = DatasetInfo(
|
||||
dataset_def=dataset,
|
||||
dataset_impl=dataset_impl,
|
||||
)
|
||||
self.dataset_infos[dataset.identifier] = dataset
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
dataset_def: Dataset,
|
||||
) -> None:
|
||||
# Store in kvstore
|
||||
key = f"{DATASETS_PREFIX}{dataset.identifier}"
|
||||
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=dataset.json(),
|
||||
)
|
||||
dataset_impl = PandasDataframeDataset(dataset)
|
||||
self.dataset_infos[dataset.identifier] = DatasetInfo(
|
||||
dataset_def=dataset,
|
||||
dataset_impl=dataset_impl,
|
||||
value=dataset_def.model_dump_json(),
|
||||
)
|
||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
key = f"{DATASETS_PREFIX}{dataset_id}"
|
||||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
dataset_info.dataset_impl.load()
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
await dataset_impl.load()
|
||||
|
||||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
start_index = start_index or 0
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
if limit is None or limit == -1:
|
||||
end = len(dataset_impl)
|
||||
else:
|
||||
next_page_token = int(page_token)
|
||||
end = min(start_index + limit, len(dataset_impl))
|
||||
|
||||
start = next_page_token
|
||||
if rows_in_page == -1:
|
||||
end = len(dataset_info.dataset_impl)
|
||||
else:
|
||||
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
|
||||
rows = dataset_impl[start_index:end]
|
||||
|
||||
rows = dataset_info.dataset_impl[start:end]
|
||||
|
||||
return PaginatedRowsResult(
|
||||
rows=rows,
|
||||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
return IterrowsResponse(
|
||||
data=rows,
|
||||
next_start_index=end if end < len(dataset_impl) else None,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
if dataset_info is None:
|
||||
raise ValueError(f"Dataset with id {dataset_id} not found")
|
||||
|
||||
dataset_impl = dataset_info.dataset_impl
|
||||
dataset_impl.load()
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
await dataset_impl.load()
|
||||
|
||||
new_rows_df = pandas.DataFrame(rows)
|
||||
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||
|
||||
url = str(dataset_info.dataset_def.url)
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
||||
file_path = parsed_url.path
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
dataset_impl.df.to_csv(file_path, index=False)
|
||||
elif parsed_url.scheme == "data":
|
||||
# For data URLs, we need to update the base64-encoded content
|
||||
if not parsed_url.path.startswith("text/csv;base64,"):
|
||||
raise ValueError("Data URL must be a base64-encoded CSV")
|
||||
|
||||
csv_buffer = dataset_impl.df.to_csv(index=False)
|
||||
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
|
||||
dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceEvalConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .eval import MetaReferenceEvalImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class MetaReferenceEvalConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "meta_reference_eval.db").as_posix()
|
||||
) # Uses SQLite config specific to Meta Reference Eval storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="meta_reference_eval.db",
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,18 +12,13 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, UserMessage
|
||||
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
MEMORY_QUERY_TOOL,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
ColumnName,
|
||||
get_valid_schemas,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
|
|
@ -88,15 +83,17 @@ class MetaReferenceEvalImpl(
|
|||
task_def = self.benchmarks[benchmark_id]
|
||||
dataset_id = task_def.dataset_id
|
||||
scoring_functions = task_def.scoring_functions
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
|
||||
# TODO (xiyan): validate dataset schema
|
||||
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||
)
|
||||
res = await self.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=all_rows.rows,
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
|
@ -118,7 +115,7 @@ class MetaReferenceEvalImpl(
|
|||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
|
||||
|
|
@ -168,10 +165,11 @@ class MetaReferenceEvalImpl(
|
|||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json]
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=candidate.model,
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
|
|
@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
|
|||
|
||||
def parse_message(json_str: str) -> ProcessingMessage:
|
||||
data = json.loads(json_str)
|
||||
return ProcessingMessageWrapper(**data).payload
|
||||
return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
|
||||
|
||||
|
||||
def worker_process_entrypoint(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
|
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
|
|||
|
||||
async def get_provider_impl(
|
||||
config: SentenceTransformersInferenceConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
|
@ -40,7 +42,7 @@ class VLLMConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||
|
|
|
|||
|
|
@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
tool_name=t.function.name,
|
||||
# vLLM function args come back as a string. Llama Stack expects JSON.
|
||||
arguments=json.loads(t.function.arguments),
|
||||
arguments_json=t.function.arguments,
|
||||
)
|
||||
for t in vllm_message.tool_calls
|
||||
],
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.type_system import (
|
||||
ChatCompletionInputType,
|
||||
DialogType,
|
||||
|
|
@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
validate_dataset_schema,
|
||||
)
|
||||
|
||||
EXPECTED_DATASET_SCHEMA = {
|
||||
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
||||
"instruct": [
|
||||
{
|
||||
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
|
||||
|
|
@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
|
|||
dataset_type: str,
|
||||
) -> None:
|
||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def:
|
||||
raise ValueError(f"Dataset {dataset_id} does not exist.")
|
||||
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import TorchtunePostTrainingConfig
|
||||
|
||||
|
|
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: TorchtunePostTrainingConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .post_training import TorchtunePostTrainingImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
|
|||
checkpoint_files: List[str],
|
||||
output_dir: str,
|
||||
model_type: str,
|
||||
) -> None:
|
||||
):
|
||||
# Fail fast if ``checkpoint_files`` is invalid
|
||||
# TODO: support loading more than one file
|
||||
if len(checkpoint_files) != 1:
|
||||
|
|
@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
|
|||
"""
|
||||
Load Meta checkpoint from file. Currently only loading from a single file is supported.
|
||||
"""
|
||||
state_dict: Dict[str:Any] = {}
|
||||
state_dict: Dict[str, Any] = {}
|
||||
model_state_dict = safe_torch_load(self._checkpoint_path)
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
|
|
@ -85,10 +85,10 @@ class TorchtuneCheckpointer:
|
|||
state_dict: Dict[str, Any],
|
||||
epoch: int,
|
||||
adapter_only: bool = False,
|
||||
checkpoint_format: str = "meta",
|
||||
checkpoint_format: str | None = None,
|
||||
) -> str:
|
||||
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
|
||||
if checkpoint_format == "meta":
|
||||
if checkpoint_format == "meta" or checkpoint_format is None:
|
||||
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
|
||||
elif checkpoint_format == "huggingface":
|
||||
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat
|
|||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_definition: Any
|
||||
tokenizer_type: Any
|
||||
model_definition: BuildLoraModelCallable
|
||||
tokenizer_type: BuildTokenizerCallable
|
||||
checkpoint_type: str
|
||||
|
||||
|
||||
|
|
@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = {
|
|||
}
|
||||
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
def _validate_model_id(model_id: str) -> Model:
|
||||
model = resolve_model(model_id)
|
||||
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -12,3 +12,9 @@ from pydantic import BaseModel
|
|||
class TorchtunePostTrainingConfig(BaseModel):
|
||||
torch_seed: Optional[int] = None
|
||||
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"checkpoint_format": "meta",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class SFTDataset(Dataset):
|
|||
if "messages" in transformed_sample:
|
||||
validate_messages(transformed_sample["messages"])
|
||||
|
||||
tokenized_dict = self._model_transform(transformed_sample)
|
||||
tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
|
||||
|
||||
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
|
||||
keys_str = ", ".join(tokenized_dict.keys())
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
|
@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
|
|||
self.jobs = {}
|
||||
self.checkpoints_dict = {}
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
|
@ -61,7 +64,7 @@ class TorchtunePostTrainingImpl:
|
|||
job_status_response = PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=JobStatus.scheduled,
|
||||
scheduled_at=datetime.now(),
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
self.jobs[job_uuid] = job_status_response
|
||||
|
||||
|
|
@ -81,7 +84,7 @@ class TorchtunePostTrainingImpl:
|
|||
)
|
||||
|
||||
job_status_response.status = JobStatus.in_progress
|
||||
job_status_response.started_at = datetime.now()
|
||||
job_status_response.started_at = datetime.now(timezone.utc)
|
||||
|
||||
await recipe.setup()
|
||||
resources_allocated, checkpoints = await recipe.train()
|
||||
|
|
@ -90,7 +93,7 @@ class TorchtunePostTrainingImpl:
|
|||
job_status_response.resources_allocated = resources_allocated
|
||||
job_status_response.checkpoints = checkpoints
|
||||
job_status_response.status = JobStatus.completed
|
||||
job_status_response.completed_at = datetime.now()
|
||||
job_status_response.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
except Exception:
|
||||
job_status_response.status = JobStatus.failed
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import gc
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
|
@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric
|
|||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
|
@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice:
|
|||
|
||||
# Currently logging only logs limited training metrics to local disk
|
||||
# will figure out more loggings and how it works with telemetry in future PRs
|
||||
|
||||
_checkpointer: TorchtuneCheckpointer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TorchtunePostTrainingConfig,
|
||||
|
|
@ -82,7 +85,7 @@ class LoraFinetuningSingleDevice:
|
|||
logger_config: Dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
) -> None:
|
||||
|
|
@ -109,12 +112,12 @@ class LoraFinetuningSingleDevice:
|
|||
return str(checkpoint_dir)
|
||||
|
||||
if checkpoint_dir and checkpoint_dir != "null":
|
||||
self.checkpoint_dir = config.checkpoint_dir
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
else:
|
||||
model = resolve_model(self.model_id)
|
||||
if model is None:
|
||||
model_obj = resolve_model(self.model_id)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
|
||||
self.checkpoint_dir = model_checkpoint_dir(model)
|
||||
self.checkpoint_dir = model_checkpoint_dir(model_obj)
|
||||
|
||||
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
|
||||
self._checkpoint_format = config.checkpoint_format
|
||||
|
|
@ -135,16 +138,16 @@ class LoraFinetuningSingleDevice:
|
|||
self.max_validation_steps = training_config.max_validation_steps
|
||||
|
||||
self._clip_grad_norm = 1.0
|
||||
self._enable_activation_checkpointing = (
|
||||
(training_config.efficiency_config.enable_activation_checkpointing)
|
||||
if training_config.efficiency_config
|
||||
else False
|
||||
)
|
||||
self._enable_activation_offloading = (
|
||||
(training_config.efficiency_config.enable_activation_offloading)
|
||||
if training_config.efficiency_config
|
||||
else False
|
||||
)
|
||||
|
||||
self._enable_activation_checkpointing = False
|
||||
self._enable_activation_offloading = False
|
||||
if training_config.efficiency_config:
|
||||
if training_config.efficiency_config.enable_activation_checkpointing:
|
||||
self._enable_activation_checkpointing = (
|
||||
training_config.efficiency_config.enable_activation_checkpointing
|
||||
)
|
||||
if training_config.efficiency_config.enable_activation_offloading:
|
||||
self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
|
||||
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
|
|
@ -328,13 +331,13 @@ class LoraFinetuningSingleDevice:
|
|||
batch_size: int,
|
||||
) -> Tuple[DistributedSampler, DataLoader]:
|
||||
async def fetch_rows(dataset_id: str):
|
||||
return await self.datasetio_api.get_rows_paginated(
|
||||
return await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
limit=-1,
|
||||
)
|
||||
|
||||
all_rows = await fetch_rows(dataset_id)
|
||||
rows = all_rows.rows
|
||||
rows = all_rows.data
|
||||
|
||||
await validate_input_dataset_schema(
|
||||
datasets_api=self.datasets_api,
|
||||
|
|
@ -451,12 +454,12 @@ class LoraFinetuningSingleDevice:
|
|||
"""
|
||||
# Initialize tokens count and running loss (for grad accumulation)
|
||||
t0 = time.perf_counter()
|
||||
running_loss = 0
|
||||
running_loss: float = 0.0
|
||||
num_tokens = 0
|
||||
|
||||
# training artifacts
|
||||
checkpoints = []
|
||||
memory_stats = {}
|
||||
memory_stats: Dict[str, Any] = {}
|
||||
|
||||
# self.epochs_run should be non-zero when we're resuming from a checkpoint
|
||||
for curr_epoch in range(self.epochs_run, self.total_epochs):
|
||||
|
|
@ -484,7 +487,7 @@ class LoraFinetuningSingleDevice:
|
|||
# Loss is normalized by default so we multiply by the number of tokens
|
||||
# This way we can normalize by the total number of tokens if we're accumulating gradients
|
||||
current_loss = await self._loss_step(batch) * current_num_tokens
|
||||
running_loss += current_loss
|
||||
running_loss += current_loss.detach().item()
|
||||
current_loss.backward()
|
||||
|
||||
# Step with optimizer
|
||||
|
|
@ -500,7 +503,7 @@ class LoraFinetuningSingleDevice:
|
|||
# Update the number of steps when the weights are updated
|
||||
self.global_step += 1
|
||||
|
||||
loss_to_log = running_loss.item() / num_tokens
|
||||
loss_to_log = running_loss / num_tokens
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
|
||||
|
|
@ -523,7 +526,7 @@ class LoraFinetuningSingleDevice:
|
|||
)
|
||||
|
||||
# Reset running stats for the next step
|
||||
running_loss = 0
|
||||
running_loss = 0.0
|
||||
num_tokens = 0
|
||||
t0 = time.perf_counter()
|
||||
|
||||
|
|
@ -532,7 +535,7 @@ class LoraFinetuningSingleDevice:
|
|||
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
|
||||
checkpoint = Checkpoint(
|
||||
identifier=f"{self.model_id}-sft-{curr_epoch}",
|
||||
created_at=datetime.now(),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
epoch=curr_epoch,
|
||||
post_training_job_id=self.job_uuid,
|
||||
path=checkpoint_path,
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps):
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
|
||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||
|
||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeScannerConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps):
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
|
||||
from .llama_guard import LlamaGuardSafetyImpl
|
||||
|
||||
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
|||
|
|
@ -4,10 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlamaGuardConfig(BaseModel):
|
||||
excluded_categories: List[str] = []
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"excluded_categories": [],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -227,13 +227,6 @@ class LlamaGuardShield:
|
|||
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
|
||||
messages = messages[1:]
|
||||
|
||||
for i in range(1, len(messages)):
|
||||
if messages[i].role == messages[i - 1].role:
|
||||
for i, m in enumerate(messages):
|
||||
print(f"{i}: {m.role}: {m.content}")
|
||||
raise ValueError(
|
||||
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
|
||||
)
|
||||
return messages
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import PromptGuardConfig # noqa: F401
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps):
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
|
||||
from .prompt_guard import PromptGuardSafetyImpl
|
||||
|
||||
impl = PromptGuardSafetyImpl(config, deps)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
|
@ -23,3 +24,9 @@ class PromptGuardConfig(BaseModel):
|
|||
if v not in [t.value for t in PromptGuardType]:
|
||||
raise ValueError(f"Unknown prompt guard type: {v}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"guard_type": "injection",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: BasicScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .scoring import BasicScoringImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BasicScoringConfig(BaseModel): ...
|
||||
class BasicScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -22,11 +22,25 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
)
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import (
|
||||
RegexParserMathResponseScoringFn,
|
||||
)
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
|
||||
FIXED_FNS = [
|
||||
EqualityScoringFn,
|
||||
SubsetOfScoringFn,
|
||||
RegexParserScoringFn,
|
||||
RegexParserMathResponseScoringFn,
|
||||
BFCLScoringFn,
|
||||
IfEvalScoringFn,
|
||||
DocVQAScoringFn,
|
||||
]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
|
|
@ -74,12 +88,12 @@ class BasicScoringImpl(
|
|||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
limit=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.rows,
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
if save_results_dataset:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.bfcl.ast_parser import decode_ast
|
||||
from ..utils.bfcl.checker import ast_checker, is_empty_output
|
||||
from .fn_defs.bfcl import bfcl
|
||||
|
||||
|
||||
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
|
||||
contain_func_call = False
|
||||
error = None
|
||||
error_type = None
|
||||
checker_result = {}
|
||||
try:
|
||||
prediction = decode_ast(x["generated_answer"], x["language"]) or ""
|
||||
contain_func_call = True
|
||||
# if not is_function_calling_format_output(prediction):
|
||||
if is_empty_output(prediction):
|
||||
contain_func_call = False
|
||||
error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
|
||||
error_type = "ast_decoder:decoder_wrong_output_format"
|
||||
else:
|
||||
checker_result = ast_checker(
|
||||
json.loads(x["function"]),
|
||||
prediction,
|
||||
json.loads(x["ground_truth"]),
|
||||
x["language"],
|
||||
test_category=test_category,
|
||||
model_name="",
|
||||
)
|
||||
except Exception as e:
|
||||
prediction = ""
|
||||
error = f"Invalid syntax. Failed to decode AST. {str(e)}"
|
||||
error_type = "ast_decoder:decoder_failed"
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"contain_func_call": contain_func_call,
|
||||
"valid": checker_result.get("valid", False),
|
||||
"error": error or checker_result.get("error", ""),
|
||||
"error_type": error_type or checker_result.get("error_type", ""),
|
||||
}
|
||||
|
||||
|
||||
def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
return {"valid": x["valid"]}
|
||||
|
||||
|
||||
def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
|
||||
# If `test_category` is "irrelevance", the model is expected to output no function call.
|
||||
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
|
||||
# If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
|
||||
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
|
||||
return {"valid": float(acc)}
|
||||
|
||||
|
||||
class BFCLScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for BFCL
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
bfcl.identifier: bfcl,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "bfcl",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
|
||||
score_result = postprocess(input_row, test_category)
|
||||
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
|
||||
score = gen_relevance_acc(score_result)["valid"]
|
||||
else:
|
||||
score = gen_valid(score_result)["valid"]
|
||||
return {
|
||||
"score": float(score),
|
||||
}
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.docvqa import docvqa
|
||||
|
||||
CONTRACTIONS = {
|
||||
"aint": "ain't",
|
||||
"arent": "aren't",
|
||||
"cant": "can't",
|
||||
"couldve": "could've",
|
||||
"couldnt": "couldn't",
|
||||
"couldn'tve": "couldn't've",
|
||||
"couldnt've": "couldn't've",
|
||||
"didnt": "didn't",
|
||||
"doesnt": "doesn't",
|
||||
"dont": "don't",
|
||||
"hadnt": "hadn't",
|
||||
"hadnt've": "hadn't've",
|
||||
"hadn'tve": "hadn't've",
|
||||
"hasnt": "hasn't",
|
||||
"havent": "haven't",
|
||||
"hed": "he'd",
|
||||
"hed've": "he'd've",
|
||||
"he'dve": "he'd've",
|
||||
"hes": "he's",
|
||||
"howd": "how'd",
|
||||
"howll": "how'll",
|
||||
"hows": "how's",
|
||||
"Id've": "I'd've",
|
||||
"I'dve": "I'd've",
|
||||
"Im": "I'm",
|
||||
"Ive": "I've",
|
||||
"isnt": "isn't",
|
||||
"itd": "it'd",
|
||||
"itd've": "it'd've",
|
||||
"it'dve": "it'd've",
|
||||
"itll": "it'll",
|
||||
"let's": "let's",
|
||||
"maam": "ma'am",
|
||||
"mightnt": "mightn't",
|
||||
"mightnt've": "mightn't've",
|
||||
"mightn'tve": "mightn't've",
|
||||
"mightve": "might've",
|
||||
"mustnt": "mustn't",
|
||||
"mustve": "must've",
|
||||
"neednt": "needn't",
|
||||
"notve": "not've",
|
||||
"oclock": "o'clock",
|
||||
"oughtnt": "oughtn't",
|
||||
"ow's'at": "'ow's'at",
|
||||
"'ows'at": "'ow's'at",
|
||||
"'ow'sat": "'ow's'at",
|
||||
"shant": "shan't",
|
||||
"shed've": "she'd've",
|
||||
"she'dve": "she'd've",
|
||||
"she's": "she's",
|
||||
"shouldve": "should've",
|
||||
"shouldnt": "shouldn't",
|
||||
"shouldnt've": "shouldn't've",
|
||||
"shouldn'tve": "shouldn't've",
|
||||
"somebody'd": "somebodyd",
|
||||
"somebodyd've": "somebody'd've",
|
||||
"somebody'dve": "somebody'd've",
|
||||
"somebodyll": "somebody'll",
|
||||
"somebodys": "somebody's",
|
||||
"someoned": "someone'd",
|
||||
"someoned've": "someone'd've",
|
||||
"someone'dve": "someone'd've",
|
||||
"someonell": "someone'll",
|
||||
"someones": "someone's",
|
||||
"somethingd": "something'd",
|
||||
"somethingd've": "something'd've",
|
||||
"something'dve": "something'd've",
|
||||
"somethingll": "something'll",
|
||||
"thats": "that's",
|
||||
"thered": "there'd",
|
||||
"thered've": "there'd've",
|
||||
"there'dve": "there'd've",
|
||||
"therere": "there're",
|
||||
"theres": "there's",
|
||||
"theyd": "they'd",
|
||||
"theyd've": "they'd've",
|
||||
"they'dve": "they'd've",
|
||||
"theyll": "they'll",
|
||||
"theyre": "they're",
|
||||
"theyve": "they've",
|
||||
"twas": "'twas",
|
||||
"wasnt": "wasn't",
|
||||
"wed've": "we'd've",
|
||||
"we'dve": "we'd've",
|
||||
"weve": "we've",
|
||||
"werent": "weren't",
|
||||
"whatll": "what'll",
|
||||
"whatre": "what're",
|
||||
"whats": "what's",
|
||||
"whatve": "what've",
|
||||
"whens": "when's",
|
||||
"whered": "where'd",
|
||||
"wheres": "where's",
|
||||
"whereve": "where've",
|
||||
"whod": "who'd",
|
||||
"whod've": "who'd've",
|
||||
"who'dve": "who'd've",
|
||||
"wholl": "who'll",
|
||||
"whos": "who's",
|
||||
"whove": "who've",
|
||||
"whyll": "why'll",
|
||||
"whyre": "why're",
|
||||
"whys": "why's",
|
||||
"wont": "won't",
|
||||
"wouldve": "would've",
|
||||
"wouldnt": "wouldn't",
|
||||
"wouldnt've": "wouldn't've",
|
||||
"wouldn'tve": "wouldn't've",
|
||||
"yall": "y'all",
|
||||
"yall'll": "y'all'll",
|
||||
"y'allll": "y'all'll",
|
||||
"yall'd've": "y'all'd've",
|
||||
"y'alld've": "y'all'd've",
|
||||
"y'all'dve": "y'all'd've",
|
||||
"youd": "you'd",
|
||||
"youd've": "you'd've",
|
||||
"you'dve": "you'd've",
|
||||
"youll": "you'll",
|
||||
"youre": "you're",
|
||||
"youve": "you've",
|
||||
"1st": "first",
|
||||
"2nd": "second",
|
||||
"3rd": "third",
|
||||
}
|
||||
NUMBERS = {
|
||||
"none": "0",
|
||||
"zero": "0",
|
||||
"one": "1",
|
||||
"two": "2",
|
||||
"three": "3",
|
||||
"four": "4",
|
||||
"five": "5",
|
||||
"six": "6",
|
||||
"seven": "7",
|
||||
"eight": "8",
|
||||
"nine": "9",
|
||||
"ten": "10",
|
||||
}
|
||||
ARTICLES = [
|
||||
"a",
|
||||
"an",
|
||||
"the",
|
||||
"to",
|
||||
"in",
|
||||
"from",
|
||||
"by",
|
||||
] # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
|
||||
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
|
||||
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
|
||||
PUNCTUATION = [
|
||||
";",
|
||||
r"/",
|
||||
"[",
|
||||
"]",
|
||||
'"',
|
||||
"{",
|
||||
"}",
|
||||
"(",
|
||||
")",
|
||||
"=",
|
||||
"+",
|
||||
"\\",
|
||||
"_",
|
||||
"-",
|
||||
">",
|
||||
"<",
|
||||
"@",
|
||||
"`",
|
||||
",",
|
||||
"?",
|
||||
"!",
|
||||
]
|
||||
|
||||
|
||||
def normalize_answer(s: str) -> str:
|
||||
# process punctuation
|
||||
for p in PUNCTUATION:
|
||||
if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
|
||||
s = s.replace(p, "")
|
||||
else:
|
||||
s = s.replace(p, " ")
|
||||
s = PERIOD_STRIP.sub("", s, re.UNICODE)
|
||||
|
||||
# process digits and articles
|
||||
temp_text = s.lower().split()
|
||||
out_text = []
|
||||
for word in temp_text:
|
||||
word = NUMBERS.setdefault(word, word)
|
||||
if word not in ARTICLES:
|
||||
out_text.append(word)
|
||||
|
||||
# standardize contractions
|
||||
for word_id, word in enumerate(out_text):
|
||||
if word in CONTRACTIONS:
|
||||
out_text[word_id] = CONTRACTIONS[word]
|
||||
return " ".join(out_text)
|
||||
|
||||
|
||||
class DocVQAScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
docvqa basically matches the generated answer against several allowed
|
||||
choices, but we need to normalize the answer to avoid penalizing
|
||||
trivial differences
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
docvqa.identifier: docvqa,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "docvqa",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
expected_answers = json.loads(input_row["expected_answer"])
|
||||
generated_answer = input_row["generated_answer"]
|
||||
score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
bfcl = ScoringFn(
|
||||
identifier="basic::bfcl",
|
||||
description="BFCL complex scoring",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="bfcl",
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
docvqa = ScoringFn(
|
||||
identifier="basic::docvqa",
|
||||
description="DocVQA Visual Question & Answer scoring function",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="docvqa",
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
ifeval = ScoringFn(
|
||||
identifier="basic::ifeval",
|
||||
description="Eval intruction follow capacity by checkping how many instructions can be followed in each example",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="ifeval",
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.weighted_average],
|
||||
),
|
||||
)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
|
||||
|
||||
|
||||
regex_parser_math_response = ScoringFn(
|
||||
identifier="basic::regex_parser_math_response",
|
||||
description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="regex-parser-math-response",
|
||||
params=RegexParserScoringFnParams(
|
||||
parsing_regexes=MATH_ANSWER_REGEXES,
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
|
||||
from .fn_defs.ifeval import (
|
||||
ifeval,
|
||||
)
|
||||
|
||||
|
||||
class IfEvalScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn Instruction-Following Eval (IFEval) benchmark
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
ifeval.identifier: ifeval,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
instruction_list = input_row["instruction_id_list"]
|
||||
generated_answer = input_row["generated_answer"].strip()
|
||||
|
||||
is_following_list = []
|
||||
results = dict(
|
||||
{k + "_correct": 0.0 for k in INSTRUCTION_LIST},
|
||||
**{k + "_total": 0.0 for k in INSTRUCTION_LIST},
|
||||
)
|
||||
|
||||
for index, instruction_id in enumerate(instruction_list):
|
||||
instruction_cls = INSTRUCTION_DICT[instruction_id]
|
||||
instruction = instruction_cls(instruction_id)
|
||||
results[instruction_id + "_total"] += 1.0
|
||||
results[instruction_id.split(":")[0] + "_total"] += 1.0
|
||||
|
||||
clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
|
||||
print(clean_input_row)
|
||||
instruction.build_description(**clean_input_row)
|
||||
args = instruction.get_instruction_args()
|
||||
if args and "prompt" in args:
|
||||
instruction.build_description(prompt=input_row["prompt"])
|
||||
|
||||
if generated_answer and instruction.check_following(generated_answer):
|
||||
is_following_list.append(True)
|
||||
results[instruction_id + "_correct"] += 1.0
|
||||
results[instruction_id.split(":")[0] + "_correct"] += 1.0
|
||||
else:
|
||||
is_following_list.append(False)
|
||||
|
||||
if len(is_following_list) == 0:
|
||||
return {
|
||||
"score": 0.0,
|
||||
"weight": 0.0,
|
||||
}
|
||||
|
||||
return {
|
||||
"score": float(sum(is_following_list)) / float(len(is_following_list)),
|
||||
"weight": float(len(is_following_list)),
|
||||
}
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
|
||||
from .fn_defs.regex_parser_math_response import (
|
||||
regex_parser_math_response,
|
||||
)
|
||||
|
||||
|
||||
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for math benchamrks that parses answer from generated response according to context and check match with expected_answer.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
regex_parser_math_response.identifier: regex_parser_math_response,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
|
||||
f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
)
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
parsing_regexes = fn_def.params.parsing_regexes
|
||||
assert len(parsing_regexes) == 1, (
|
||||
"Only one parsing regex is supported for regex_parser_math_response scoring function."
|
||||
)
|
||||
parsing_regexes = fn_def.params.parsing_regexes[0]
|
||||
|
||||
normalized_generated_answer = normalize_final_answer(
|
||||
first_answer(generated_answer),
|
||||
parsing_regexes,
|
||||
match_first=True,
|
||||
)
|
||||
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
|
||||
|
||||
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
|
||||
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
|
||||
|
||||
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
|
@ -3,10 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -0,0 +1,296 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import ast
|
||||
|
||||
from .tree_sitter import get_parser
|
||||
|
||||
|
||||
def parse_java_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("java")
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
|
||||
if root_node.has_error:
|
||||
raise Exception("Error parsing java the source code.")
|
||||
|
||||
def get_text(node):
|
||||
"""Returns the text represented by the node."""
|
||||
return source_code[node.start_byte : node.end_byte]
|
||||
|
||||
def traverse_node(node, nested=False):
|
||||
if node.type == "string_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding quotes from string literals
|
||||
return get_text(node)[1:-1]
|
||||
elif node.type == "character_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding single quotes from character literals
|
||||
return get_text(node)[1:-1]
|
||||
"""Traverse the node to collect texts for complex structures."""
|
||||
if node.type in [
|
||||
"identifier",
|
||||
"class_literal",
|
||||
"type_identifier",
|
||||
"method_invocation",
|
||||
]:
|
||||
return get_text(node)
|
||||
elif node.type == "array_creation_expression":
|
||||
# Handle array creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
value_node = node.child_by_field_name("value")
|
||||
type_text = traverse_node(type_node, True)
|
||||
value_text = traverse_node(value_node, True)
|
||||
return f"new {type_text}[]{value_text}"
|
||||
elif node.type == "object_creation_expression":
|
||||
# Handle object creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
type_text = traverse_node(type_node, True)
|
||||
if arguments_node:
|
||||
# Process each argument carefully, avoiding unnecessary punctuation
|
||||
argument_texts = []
|
||||
for child in arguments_node.children:
|
||||
if child.type not in [
|
||||
",",
|
||||
"(",
|
||||
")",
|
||||
]: # Exclude commas and parentheses
|
||||
argument_text = traverse_node(child, True)
|
||||
argument_texts.append(argument_text)
|
||||
arguments_text = ", ".join(argument_texts)
|
||||
return f"new {type_text}({arguments_text})"
|
||||
else:
|
||||
return f"new {type_text}()"
|
||||
elif node.type == "set":
|
||||
# Handling sets specifically
|
||||
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
|
||||
return "{" + ", ".join(items) + "}"
|
||||
|
||||
elif node.child_count > 0:
|
||||
return "".join(traverse_node(child, True) for child in node.children)
|
||||
else:
|
||||
return get_text(node)
|
||||
|
||||
def extract_arguments(args_node):
|
||||
arguments = {}
|
||||
for child in args_node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# For named parameters
|
||||
name_node, value_node = child.children[0], child.children[2]
|
||||
name = get_text(name_node)
|
||||
value = traverse_node(value_node)
|
||||
if name in arguments:
|
||||
if not isinstance(arguments[name], list):
|
||||
arguments[name] = [arguments[name]]
|
||||
arguments[name].append(value)
|
||||
else:
|
||||
arguments[name] = value
|
||||
# arguments.append({'name': name, 'value': value})
|
||||
elif child.type in ["identifier", "class_literal", "set"]:
|
||||
# For unnamed parameters and handling sets
|
||||
value = traverse_node(child)
|
||||
if None in arguments:
|
||||
if not isinstance(arguments[None], list):
|
||||
arguments[None] = [arguments[None]]
|
||||
arguments[None].append(value)
|
||||
else:
|
||||
arguments[None] = value
|
||||
return arguments
|
||||
|
||||
def traverse(node):
|
||||
if node.type == "method_invocation":
|
||||
# Extract the function name and its arguments
|
||||
method_name = get_text(node.child_by_field_name("name"))
|
||||
class_name_node = node.child_by_field_name("object")
|
||||
if class_name_node:
|
||||
class_name = get_text(class_name_node)
|
||||
function_name = f"{class_name}.{method_name}"
|
||||
else:
|
||||
function_name = method_name
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
if arguments_node:
|
||||
arguments = extract_arguments(arguments_node)
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
return [{function_name: arguments}]
|
||||
|
||||
else:
|
||||
for child in node.children:
|
||||
result = traverse(child)
|
||||
if result:
|
||||
return result
|
||||
|
||||
result = traverse(root_node)
|
||||
return result if result else {}
|
||||
|
||||
|
||||
def parse_javascript_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("javascript")
|
||||
# Parse the source code
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
if root_node.has_error:
|
||||
raise Exception("Error js parsing the source code.")
|
||||
|
||||
# Function to recursively extract argument details
|
||||
def extract_arguments(node):
|
||||
args = {}
|
||||
for child in node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# Extract left (name) and right (value) parts of the assignment
|
||||
name = child.children[0].text.decode("utf-8")
|
||||
value = child.children[2].text.decode("utf-8")
|
||||
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
|
||||
value = value[1:-1] # Trim the quotation marks
|
||||
if name in args:
|
||||
if not isinstance(args[name], list):
|
||||
args[name] = [args[name]]
|
||||
args[name].append(value)
|
||||
else:
|
||||
args[name] = value
|
||||
|
||||
elif child.type == "identifier" or child.type == "true":
|
||||
# Handle non-named arguments and boolean values
|
||||
value = child.text.decode("utf-8")
|
||||
if None in args:
|
||||
if not isinstance(args[None], list):
|
||||
args[None] = [args[None]]
|
||||
args[None].append(value)
|
||||
else:
|
||||
args[None] = value
|
||||
return args
|
||||
|
||||
# Find the function call and extract its name and arguments
|
||||
if root_node.type == "program":
|
||||
for child in root_node.children:
|
||||
if child.type == "expression_statement":
|
||||
for sub_child in child.children:
|
||||
if sub_child.type == "call_expression":
|
||||
function_name = sub_child.children[0].text.decode("utf8")
|
||||
arguments_node = sub_child.children[1]
|
||||
parameters = extract_arguments(arguments_node)
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
result = [{function_name: parameters}]
|
||||
return result
|
||||
|
||||
|
||||
def ast_parse(input_str, language="Python"):
|
||||
if language == "Python":
|
||||
cleaned_input = input_str.strip("[]'")
|
||||
parsed = ast.parse(cleaned_input, mode="eval")
|
||||
extracted = []
|
||||
if isinstance(parsed.body, ast.Call):
|
||||
extracted.append(resolve_ast_call(parsed.body))
|
||||
else:
|
||||
for elem in parsed.body.elts:
|
||||
extracted.append(resolve_ast_call(elem))
|
||||
return extracted
|
||||
elif language == "Java":
|
||||
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
|
||||
elif language == "JavaScript":
|
||||
return parse_javascript_function_call(input_str[1:-1])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported language: {language}")
|
||||
|
||||
|
||||
def resolve_ast_call(elem):
|
||||
# Handle nested attributes for deeply nested module paths
|
||||
func_parts = []
|
||||
func_part = elem.func
|
||||
while isinstance(func_part, ast.Attribute):
|
||||
func_parts.append(func_part.attr)
|
||||
func_part = func_part.value
|
||||
if isinstance(func_part, ast.Name):
|
||||
func_parts.append(func_part.id)
|
||||
func_name = ".".join(reversed(func_parts))
|
||||
args_dict = {}
|
||||
# Parse when args are simply passed as an unnamed dictionary arg
|
||||
for arg in elem.args:
|
||||
if isinstance(arg, ast.Dict):
|
||||
for key, value in zip(arg.keys, arg.values):
|
||||
if isinstance(key, ast.Constant):
|
||||
arg_name = key.value
|
||||
output = resolve_ast_by_type(value)
|
||||
args_dict[arg_name] = output
|
||||
for arg in elem.keywords:
|
||||
output = resolve_ast_by_type(arg.value)
|
||||
args_dict[arg.arg] = output
|
||||
return {func_name: args_dict}
|
||||
|
||||
|
||||
def resolve_ast_by_type(value):
|
||||
if isinstance(value, ast.Constant):
|
||||
if value.value is Ellipsis:
|
||||
output = "..."
|
||||
else:
|
||||
output = value.value
|
||||
elif isinstance(value, ast.UnaryOp):
|
||||
output = -value.operand.value
|
||||
elif isinstance(value, ast.List):
|
||||
output = [resolve_ast_by_type(v) for v in value.elts]
|
||||
elif isinstance(value, ast.Dict):
|
||||
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
|
||||
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
|
||||
output = value.value
|
||||
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
|
||||
output = eval(ast.unparse(value))
|
||||
elif isinstance(value, ast.Name):
|
||||
output = value.id
|
||||
elif isinstance(value, ast.Call):
|
||||
if len(value.keywords) == 0:
|
||||
output = ast.unparse(value)
|
||||
else:
|
||||
output = resolve_ast_call(value)
|
||||
elif isinstance(value, ast.Tuple):
|
||||
output = tuple(resolve_ast_by_type(v) for v in value.elts)
|
||||
elif isinstance(value, ast.Lambda):
|
||||
output = eval(ast.unparse(value.body[0].value))
|
||||
elif isinstance(value, ast.Ellipsis):
|
||||
output = "..."
|
||||
elif isinstance(value, ast.Subscript):
|
||||
try:
|
||||
output = ast.unparse(value.body[0].value)
|
||||
except:
|
||||
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
|
||||
else:
|
||||
raise Exception(f"Unsupported AST type: {type(value)}")
|
||||
return output
|
||||
|
||||
|
||||
def decode_ast(result, language="Python"):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decoded_output = ast_parse(func, language)
|
||||
return decoded_output
|
||||
|
||||
|
||||
def decode_execute(result):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decode_output = ast_parse(func)
|
||||
execution_list = []
|
||||
for function_call in decode_output:
|
||||
for key, value in function_call.items():
|
||||
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
|
||||
return execution_list
|
||||
989
llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py
Normal file
989
llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py
Normal file
|
|
@ -0,0 +1,989 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
# Comment out for now until we actually use the rest checker in evals
|
||||
# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function.
|
||||
|
||||
|
||||
class NoAPIKeyError(Exception):
|
||||
def __init__(self):
|
||||
self.message = "❗️Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate."
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2
|
||||
|
||||
|
||||
JAVA_TYPE_CONVERSION = {
|
||||
"byte": int,
|
||||
"short": int,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"double": float,
|
||||
"long": int,
|
||||
"boolean": bool,
|
||||
"char": str,
|
||||
"Array": list,
|
||||
"ArrayList": list,
|
||||
"Set": set,
|
||||
"HashMap": dict,
|
||||
"Hashtable": dict,
|
||||
"Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list
|
||||
"Stack": list,
|
||||
"String": str,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
JS_TYPE_CONVERSION = {
|
||||
"String": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"Bigint": int,
|
||||
"Boolean": bool,
|
||||
"dict": dict,
|
||||
"array": list,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# We switch to conditional import for the following two imports to avoid unnecessary installations.
|
||||
# User doesn't need to setup the tree-sitter packages if they are not running the test for that language.
|
||||
# from js_type_converter import js_type_converter
|
||||
# from java_type_converter import java_type_converter
|
||||
|
||||
PYTHON_TYPE_MAPPING = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"tuple": list,
|
||||
"dict": dict,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# This is the list of types that we need to recursively check its values
|
||||
PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"]
|
||||
|
||||
|
||||
NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"]
|
||||
|
||||
|
||||
#### Helper functions for AST ####
|
||||
def find_description(func_descriptions, name):
|
||||
if type(func_descriptions) == list:
|
||||
for func_description in func_descriptions:
|
||||
if func_description["name"] == name:
|
||||
return func_description
|
||||
return None
|
||||
else:
|
||||
# it is a dict, there is only one function
|
||||
return func_descriptions
|
||||
|
||||
|
||||
def get_possible_answer_type(possible_answer: list):
|
||||
for answer in possible_answer:
|
||||
if answer != "": # Optional parameter
|
||||
return type(answer)
|
||||
return None
|
||||
|
||||
|
||||
def type_checker(
|
||||
param: str,
|
||||
value,
|
||||
possible_answer: list,
|
||||
expected_type_description: str,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
):
|
||||
# NOTE: This type checker only supports nested type checking for one level deep.
|
||||
# We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
|
||||
|
||||
result: Any = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"is_variable": False,
|
||||
"error_type": "type_error:simple",
|
||||
}
|
||||
|
||||
is_variable = False
|
||||
# check for the case where a variable is used instead of a actual value.
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if possible_answer_type != expected_type_converted:
|
||||
is_variable = True
|
||||
|
||||
# value is the same type as in function description
|
||||
if type(value) == expected_type_converted:
|
||||
# We don't need to do recursive check for simple types
|
||||
if nested_type_converted == None:
|
||||
result["is_variable"] = is_variable
|
||||
return result
|
||||
else:
|
||||
for possible_answer_item in possible_answer:
|
||||
flag = True # Each parameter should match to at least one possible answer type.
|
||||
# Here, we assume that each item should be the same type. We could also relax it.
|
||||
if type(possible_answer_item) == list:
|
||||
for value_item in value:
|
||||
checker_result = type_checker(
|
||||
param,
|
||||
value_item,
|
||||
possible_answer_item,
|
||||
str(nested_type_converted),
|
||||
nested_type_converted,
|
||||
None,
|
||||
)
|
||||
if not checker_result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": [], "is_variable": is_variable}
|
||||
|
||||
result["valid"] = False
|
||||
result["error"] = [
|
||||
f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
|
||||
]
|
||||
result["error_type"] = "type_error:nested"
|
||||
|
||||
# value is not as expected, check for the case where a variable is used instead of a actual value
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if type(value) == possible_answer_type:
|
||||
result["is_variable"] = True
|
||||
return result
|
||||
|
||||
result["valid"] = False
|
||||
result["error"].append(
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:simple"
|
||||
return result
|
||||
|
||||
|
||||
def standardize_string(input_string: str):
|
||||
# This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase
|
||||
# It will also convert all the single quotes to double quotes
|
||||
# This is used to compare the model output with the possible answers
|
||||
# We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024
|
||||
regex_string = r"[ \,\.\/\-\_\*\^]"
|
||||
return re.sub(regex_string, "", input_string).lower().replace("'", '"')
|
||||
|
||||
|
||||
def string_checker(param: str, model_output: str, possible_answer: list):
|
||||
standardize_possible_answer = []
|
||||
standardize_model_output = standardize_string(model_output)
|
||||
for i in range(len(possible_answer)):
|
||||
if type(possible_answer[i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[i]))
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive."
|
||||
],
|
||||
"error_type": "value_error:string",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def list_checker(param: str, model_output: list, possible_answer: list):
|
||||
# Convert the tuple to a list
|
||||
|
||||
standardize_model_output = list(model_output)
|
||||
|
||||
# If the element in the list is a string, we need to standardize it
|
||||
for i in range(len(standardize_model_output)):
|
||||
if type(standardize_model_output[i]) == str:
|
||||
standardize_model_output[i] = standardize_string(model_output[i])
|
||||
|
||||
standardize_possible_answer: Any = []
|
||||
# We also need to standardize the possible answers
|
||||
for i in range(len(possible_answer)):
|
||||
standardize_possible_answer.append([])
|
||||
for j in range(len(possible_answer[i])):
|
||||
if type(possible_answer[i][j]) == str:
|
||||
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
|
||||
else:
|
||||
standardize_possible_answer[i].append(possible_answer[i][j])
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}."
|
||||
],
|
||||
"error_type": "value_error:list/tuple",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def dict_checker(param: str, model_output: dict, possible_answers: list):
|
||||
# This function works for simple dictionaries, but not dictionaries with nested dictionaries.
|
||||
# The current dataset only contains simple dictionaries, so this is sufficient.
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
for i in range(len(possible_answers)):
|
||||
if possible_answers[i] == "":
|
||||
continue
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
|
||||
flag = True
|
||||
|
||||
possible_answer = possible_answers[i]
|
||||
# possible_anwer is a single dictionary
|
||||
|
||||
for key, value in model_output.items():
|
||||
if key not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
standardize_value = value
|
||||
# If the value is a string, we need to standardize it
|
||||
if type(value) == str:
|
||||
standardize_value = standardize_string(value)
|
||||
|
||||
# We also need to standardize the possible answers if they are string
|
||||
standardize_possible_answer = []
|
||||
for i in range(len(possible_answer[key])):
|
||||
if type(possible_answer[key][i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
|
||||
else:
|
||||
standardize_possible_answer.append(possible_answer[key][i])
|
||||
|
||||
if standardize_value not in standardize_possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
|
||||
)
|
||||
result["error_type"] = "value_error:dict_value"
|
||||
flag = False
|
||||
break
|
||||
|
||||
for key, value in possible_answer.items():
|
||||
if key not in model_output and "" not in value:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def list_dict_checker(param: str, model_output: list, possible_answers: list):
|
||||
# This function takes in a list of dictionaries and checks if each dictionary is valid
|
||||
# The order of the dictionaries in the list must match the order of the possible answers
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
|
||||
|
||||
for answer_index in range(len(possible_answers)):
|
||||
flag = True # True means so far, all dictionaries are valid
|
||||
|
||||
# Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
|
||||
if len(model_output) != len(possible_answers[answer_index]):
|
||||
result["valid"] = False
|
||||
result["error"] = ["Wrong number of dictionaries in the list."]
|
||||
result["error_type"] = "value_error:list_dict_count"
|
||||
flag = False
|
||||
continue
|
||||
|
||||
for dict_index in range(len(model_output)):
|
||||
result = dict_checker(
|
||||
param,
|
||||
model_output[dict_index],
|
||||
[possible_answers[answer_index][dict_index]],
|
||||
)
|
||||
if not result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def simple_function_checker(
|
||||
func_description: dict,
|
||||
model_output: dict,
|
||||
possible_answer: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
possible_answer = list(possible_answer.values())[0]
|
||||
# Extract function name and parameters details
|
||||
func_name = func_description["name"]
|
||||
param_details = func_description["parameters"]["properties"]
|
||||
required_params = func_description["parameters"]["required"]
|
||||
|
||||
# Initialize a result dictionary
|
||||
result = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"error_type": "simple_function_checker:unclear",
|
||||
}
|
||||
|
||||
# Check if function name matches
|
||||
if func_name not in model_output:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Function name {repr(func_name)} not found in model output."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:wrong_func_name"
|
||||
return result
|
||||
|
||||
model_params = model_output[func_name]
|
||||
|
||||
# Check for required parameters in model output
|
||||
for param in required_params:
|
||||
if param not in model_params:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:missing_required"
|
||||
return result
|
||||
|
||||
# Validate types and values for each parameter in model output
|
||||
for param, value in model_params.items():
|
||||
if param not in param_details or param not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:unexpected_param"
|
||||
return result
|
||||
|
||||
full_param_details = param_details[param]
|
||||
expected_type_description = full_param_details["type"] # This is a string
|
||||
is_variable = False
|
||||
nested_type_converted = None
|
||||
|
||||
if language == "Java":
|
||||
from evals.utils.bfcl.java_type_converter import java_type_converter
|
||||
|
||||
expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JAVA_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:java"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
|
||||
value = java_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = java_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "JavaScript":
|
||||
from evals.utils.bfcl.js_type_converter import js_type_converter
|
||||
|
||||
expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JS_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:js"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
|
||||
value = js_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = js_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "Python":
|
||||
expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
|
||||
if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
|
||||
|
||||
# We convert all tuple value to list when the expected type is tuple.
|
||||
# The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
|
||||
# This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future.
|
||||
if expected_type_description == "tuple" and type(value) == tuple:
|
||||
value = list(value)
|
||||
|
||||
# Allow python auto conversion from int to float
|
||||
if language == "Python" and expected_type_description == "float" and type(value) == int:
|
||||
value = float(value)
|
||||
|
||||
# Type checking
|
||||
# In fact, we only check for Python here.
|
||||
# Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct.
|
||||
type_check_result = type_checker(
|
||||
param,
|
||||
value,
|
||||
possible_answer[param],
|
||||
expected_type_description,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
)
|
||||
is_variable = type_check_result["is_variable"]
|
||||
if not type_check_result["valid"]:
|
||||
return type_check_result
|
||||
|
||||
# It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable.
|
||||
# We can just treat the variable as a string and use the normal flow.
|
||||
if not is_variable:
|
||||
# Special handle for dictionaries
|
||||
if expected_type_converted == dict:
|
||||
result = dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for list of dictionaries
|
||||
elif expected_type_converted == list and nested_type_converted == dict:
|
||||
result = list_dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for strings
|
||||
elif expected_type_converted == str:
|
||||
# We don't check for case sensitivity for string, as long as it's not a variable
|
||||
result = string_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
elif expected_type_converted == list:
|
||||
result = list_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Check if the value is within the possible answers
|
||||
if value not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
|
||||
)
|
||||
result["error_type"] = "value_error:others"
|
||||
return result
|
||||
|
||||
# Check for optional parameters not provided but allowed
|
||||
for param in possible_answer:
|
||||
if param not in model_params and "" not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Optional parameter {repr(param)} not provided and not marked as optional."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:missing_optional"
|
||||
return result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parallel_function_checker_enforce_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_enforce_order:wrong_count",
|
||||
}
|
||||
|
||||
func_name_list = list(possible_answers.keys())
|
||||
possible_answers_list = []
|
||||
|
||||
for key, value in possible_answers.items():
|
||||
possible_answers_list.append({key: value})
|
||||
|
||||
for i in range(len(possible_answers_list)):
|
||||
func_description = find_description(func_descriptions, func_name_list[i])
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[i],
|
||||
possible_answers_list[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
if not result["valid"]:
|
||||
return result
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def parallel_function_checker_no_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_no_order:wrong_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
|
||||
# We go throught the possible answers one by one, and eliminate the model output that matches the possible answer
|
||||
# It must be this way because we need ground truth to fetch the correct function description
|
||||
for i in range(len(possible_answers)):
|
||||
# possible_answers[i] is a dictionary with only one key
|
||||
func_name_expected = list(possible_answers[i].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
|
||||
all_errors = []
|
||||
|
||||
for index in range(len(model_output)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[index],
|
||||
possible_answers[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_output_item": model_output[index],
|
||||
"possible_answer_item": possible_answers[i],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "parallel_function_checker_no_order:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def multiple_function_checker(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "multiple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
# possible_answers is a list of only one dictionary with only one key
|
||||
func_name_expected = list(possible_answers[0].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
return simple_function_checker(
|
||||
func_description,
|
||||
model_output[0],
|
||||
possible_answers[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def patten_matcher(exec_output, expected_result, function_call, is_sanity_check):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
if type(exec_output) != type(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == dict:
|
||||
# We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one.
|
||||
# This happens when the key is a timestamp or a random number.
|
||||
if is_sanity_check:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
else:
|
||||
return result
|
||||
|
||||
for key, value in expected_result.items():
|
||||
if key not in exec_output:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_key_not_found",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
for key, value in exec_output.items():
|
||||
if key not in expected_result:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_extra_key",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == list:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:list_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
#### Helper functions for Exec ####
|
||||
def executable_checker_simple(
|
||||
function_call: str,
|
||||
expected_result,
|
||||
expected_result_type: str,
|
||||
is_sanity_check=False,
|
||||
):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
exec_dict: Any = {}
|
||||
|
||||
try:
|
||||
exec(
|
||||
"from executable_python_function import *" + "\nresult=" + function_call,
|
||||
exec_dict,
|
||||
)
|
||||
exec_output = exec_dict["result"]
|
||||
except NoAPIKeyError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Error in execution: {repr(function_call)}. Error: {str(e)}"
|
||||
)
|
||||
result["error_type"] = "executable_checker:execution_error"
|
||||
return result
|
||||
|
||||
# We need to special handle the case where the execution result is a tuple and convert it to a list
|
||||
# Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json
|
||||
if isinstance(exec_output, tuple):
|
||||
exec_output = list(exec_output)
|
||||
|
||||
if expected_result_type == "exact_match":
|
||||
if exec_output != expected_result:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
elif expected_result_type == "real_time_match":
|
||||
# Allow for 5% difference
|
||||
if (type(expected_result) == float or type(expected_result) == int) and (
|
||||
type(exec_output) == float or type(exec_output) == int
|
||||
):
|
||||
if not (
|
||||
expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
<= exec_output
|
||||
<= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
):
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
else:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
else:
|
||||
# structural match
|
||||
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
|
||||
if not pattern_match_result["valid"]:
|
||||
return pattern_match_result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def executable_checker_parallel_no_order(
|
||||
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
|
||||
):
|
||||
if len(decoded_result) != len(expected_exec_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
|
||||
],
|
||||
"error_type": "value_error:exec_result_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
for i in range(len(expected_exec_result)):
|
||||
all_errors = []
|
||||
for index in range(len(decoded_result)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = executable_checker_simple(
|
||||
decoded_result[index],
|
||||
expected_exec_result[i],
|
||||
expected_exec_result_type[i],
|
||||
False,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_executed_output": (
|
||||
result["model_executed_output"] if "model_executed_output" in result else None
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "executable_checker:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
|
||||
#### Main function ####
|
||||
def executable_checker_rest(func_call, idx):
|
||||
# Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
|
||||
EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution
|
||||
with open(EVAL_GROUND_TRUTH_PATH, "r") as f:
|
||||
EVAL_GROUND_TRUTH = f.readlines()
|
||||
if "https://geocode.maps.co" in func_call:
|
||||
time.sleep(2)
|
||||
if "requests_get" in func_call:
|
||||
func_call = func_call.replace("requests_get", "requests.get")
|
||||
try:
|
||||
response = eval(func_call)
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution failed. {str(e)}"],
|
||||
"error_type": "executable_checker_rest:execution_error",
|
||||
}
|
||||
|
||||
try:
|
||||
if response.status_code == 200:
|
||||
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
|
||||
try:
|
||||
if isinstance(eval_GT_json, dict):
|
||||
if isinstance(response.json(), dict):
|
||||
if set(eval_GT_json.keys()) == set(response.json().keys()):
|
||||
return {"valid": True, "error": [], "error_type": ""}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dictionary, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
|
||||
elif isinstance(eval_GT_json, list):
|
||||
if isinstance(response.json(), list):
|
||||
if len(eval_GT_json) != len(response.json()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Response list length inconsistency."],
|
||||
"error_type": "value_error:exec_result_rest_count",
|
||||
}
|
||||
|
||||
else:
|
||||
for i in range(len(eval_GT_json)):
|
||||
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dict or list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}"
|
||||
],
|
||||
"error_type": "executable_checker_rest:response_format_error",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution result status code is not 200, got {response.status_code}"],
|
||||
"error_type": "executable_checker_rest:wrong_status_code",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Cannot get status code of the response. Error: {str(e)}"],
|
||||
"error_type": "executable_checker_rest:cannot_get_status_code",
|
||||
}
|
||||
|
||||
|
||||
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
|
||||
if "parallel" in test_category:
|
||||
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
elif "multiple" in test_category:
|
||||
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
else:
|
||||
if len(model_output) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
return simple_function_checker(
|
||||
func_description[0],
|
||||
model_output[0],
|
||||
possible_answer[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def exec_checker(decoded_result: list, func_description: dict, test_category: str):
|
||||
if "multiple" in test_category or "parallel" in test_category:
|
||||
return executable_checker_parallel_no_order(
|
||||
decoded_result,
|
||||
func_description["execution_result"],
|
||||
func_description["execution_result_type"],
|
||||
)
|
||||
|
||||
else:
|
||||
if len(decoded_result) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_exec_checker:wrong_count",
|
||||
}
|
||||
return executable_checker_simple(
|
||||
decoded_result[0],
|
||||
func_description["execution_result"][0],
|
||||
func_description["execution_result_type"][0],
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
def is_empty_output(decoded_output):
|
||||
# This function is a patch to the ast decoder for relevance detection
|
||||
# Sometimes the ast decoder will parse successfully, but the input doens't really have a function call
|
||||
# [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct)
|
||||
if not is_function_calling_format_output(decoded_output):
|
||||
return True
|
||||
if len(decoded_output) == 0:
|
||||
return True
|
||||
if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
|
||||
return True
|
||||
|
||||
|
||||
def is_function_calling_format_output(decoded_output):
|
||||
# Ensure the output is a list of dictionaries
|
||||
if type(decoded_output) == list:
|
||||
for item in decoded_output:
|
||||
if type(item) != dict:
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
|
||||
import it from here so that we can centrally manage things as necessary.
|
||||
"""
|
||||
|
||||
# These currently work with tree-sitter 0.23.0
|
||||
# NOTE: Don't import tree-sitter or any of the language modules in the main module
|
||||
# because not all environments have them. Import lazily inside functions where needed.
|
||||
|
||||
import importlib
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import tree_sitter
|
||||
|
||||
|
||||
def get_language(language: str) -> "tree_sitter.Language":
|
||||
import tree_sitter
|
||||
|
||||
language_module_name = f"tree_sitter_{language}"
|
||||
try:
|
||||
language_module = importlib.import_module(language_module_name)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
f"Language {language} is not found. Please install the tree-sitter-{language} package."
|
||||
) from exc
|
||||
return tree_sitter.Language(language_module.language())
|
||||
|
||||
|
||||
def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
|
||||
import tree_sitter
|
||||
|
||||
lang = get_language(language)
|
||||
return tree_sitter.Parser(lang, **kwargs)
|
||||
3319
llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py
Normal file
3319
llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py
Normal file
File diff suppressed because it is too large
Load diff
330
llama_stack/providers/inline/scoring/basic/utils/math_utils.py
Normal file
330
llama_stack/providers/inline/scoring/basic/utils/math_utils.py
Normal file
|
|
@ -0,0 +1,330 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
from typing import Sequence
|
||||
|
||||
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
|
||||
|
||||
# from minerva
|
||||
SUBSTITUTIONS = [
|
||||
("an ", ""),
|
||||
("a ", ""),
|
||||
(".$", "$"),
|
||||
("\\$", ""),
|
||||
(r"\ ", ""),
|
||||
(" ", ""),
|
||||
("mbox", "text"),
|
||||
(",\\text{and}", ","),
|
||||
("\\text{and}", ","),
|
||||
("\\text{m}", "\\text{}"),
|
||||
]
|
||||
|
||||
REMOVED_EXPRESSIONS = [
|
||||
"square",
|
||||
"ways",
|
||||
"integers",
|
||||
"dollars",
|
||||
"mph",
|
||||
"inches",
|
||||
"ft",
|
||||
"hours",
|
||||
"km",
|
||||
"units",
|
||||
"\\ldots",
|
||||
"sue",
|
||||
"points",
|
||||
"feet",
|
||||
"minutes",
|
||||
"digits",
|
||||
"cents",
|
||||
"degrees",
|
||||
"cm",
|
||||
"gm",
|
||||
"pounds",
|
||||
"meters",
|
||||
"meals",
|
||||
"edges",
|
||||
"students",
|
||||
"childrentickets",
|
||||
"multiples",
|
||||
"\\text{s}",
|
||||
"\\text{.}",
|
||||
"\\text{\ns}",
|
||||
"\\text{}^2",
|
||||
"\\text{}^3",
|
||||
"\\text{\n}",
|
||||
"\\text{}",
|
||||
r"\mathrm{th}",
|
||||
r"^\circ",
|
||||
r"^{\circ}",
|
||||
r"\;",
|
||||
r",\!",
|
||||
"{,}",
|
||||
'"',
|
||||
"\\dots",
|
||||
]
|
||||
|
||||
|
||||
def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
|
||||
if isinstance(expression, float):
|
||||
return expression
|
||||
new_expression = f"{expression}"
|
||||
regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}")
|
||||
for match in re.finditer(regex, expression):
|
||||
try:
|
||||
value = float(match.group(1)) / float(match.group(2))
|
||||
new_expression = new_expression.replace(
|
||||
match.group(),
|
||||
f"{{value:{fmt}}}".format(value=value),
|
||||
1,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
return new_expression
|
||||
|
||||
|
||||
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
|
||||
try:
|
||||
with time_limit(seconds=5):
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
value = parse_latex(expression).evalf() # type: ignore
|
||||
return f"{{value:{fmt}}}".format(value=value)
|
||||
except Exception:
|
||||
return expression
|
||||
|
||||
|
||||
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
|
||||
for marker in markers:
|
||||
text = text.split(marker)[0]
|
||||
return text
|
||||
|
||||
|
||||
def extract_result_from_boxed(answer: str) -> str:
|
||||
box_start = "\\boxed"
|
||||
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
|
||||
start = answer.rfind(box_start)
|
||||
if start < 0:
|
||||
return ""
|
||||
answer = answer[start + len(box_start) :].strip()
|
||||
ends_with_curly = answer.startswith("{")
|
||||
i = 0
|
||||
open_braces = 0
|
||||
while i < len(answer):
|
||||
if answer[i] == "{":
|
||||
open_braces += 1
|
||||
elif answer[i] == "}":
|
||||
open_braces -= 1
|
||||
if open_braces == 0:
|
||||
if ends_with_curly:
|
||||
answer = answer[: i + 1].strip()
|
||||
break
|
||||
elif answer[i] == "$":
|
||||
answer = answer[:i].strip()
|
||||
break
|
||||
i += 1
|
||||
else:
|
||||
return ""
|
||||
# remove extra curly braces
|
||||
while True:
|
||||
if answer.startswith("{") and answer.endswith("}"):
|
||||
answer = answer[1:-1].strip()
|
||||
else:
|
||||
break
|
||||
return answer
|
||||
|
||||
|
||||
# from minerva paper + _normalise_result from xavierm
|
||||
def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
|
||||
"""Extract and normalize a final answer to a quantitative reasoning question."""
|
||||
match = re.findall(regex_pattern, final_answer)
|
||||
extraction: str
|
||||
if len(match) > 0:
|
||||
if match_first:
|
||||
extraction = match[0]
|
||||
else:
|
||||
extraction = match[-1]
|
||||
else:
|
||||
extraction = extract_result_from_boxed(final_answer)
|
||||
|
||||
if len(extraction) == 0:
|
||||
return final_answer
|
||||
else:
|
||||
final_answer = extraction
|
||||
final_answer = final_answer.split("=")[-1]
|
||||
for before, after in SUBSTITUTIONS:
|
||||
final_answer = final_answer.replace(before, after)
|
||||
for expr in REMOVED_EXPRESSIONS:
|
||||
final_answer = final_answer.replace(expr, "")
|
||||
# Extract answer that is in LaTeX math, is bold,
|
||||
# is surrounded by a box, etc.
|
||||
final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
|
||||
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
|
||||
# Normalize shorthand TeX:
|
||||
# \fracab -> \frac{a}{b}
|
||||
# \frac{abc}{bef} -> \frac{abc}{bef}
|
||||
# \fracabc -> \frac{a}{b}c
|
||||
# \sqrta -> \sqrt{a}
|
||||
# \sqrtab -> sqrt{a}b
|
||||
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
|
||||
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
|
||||
final_answer = final_answer.replace("$", "")
|
||||
# Normalize 100,000 -> 100000
|
||||
if final_answer.replace(",", "").isdigit():
|
||||
final_answer = final_answer.replace(",", "")
|
||||
# If the final answer is a single letter in parentheses, remove the parentheses
|
||||
# Example: (a) -> a (but not (ab) -> ab)
|
||||
if re.match(r"\([a-zA-Z]\)", final_answer):
|
||||
final_answer = final_answer[1]
|
||||
return _normalise_result(final_answer)
|
||||
|
||||
|
||||
def _normalise_result(string: str) -> str:
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# remove inverse spaces
|
||||
string = string.replace("\\!", "")
|
||||
|
||||
# replace \\ with \
|
||||
string = string.replace("\\\\", "\\")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("cfrac", "frac")
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\le", "")
|
||||
string = string.replace("\\right", "")
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
|
||||
# remove units (on the right)
|
||||
string = _remove_right_units(string)
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace(r"\%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
string = string.split("=")[-1]
|
||||
|
||||
# fix sqrt3 --> sqrt{3}
|
||||
string = _fix_sqrt(string)
|
||||
|
||||
# remove spaces
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = _fix_fracs(string)
|
||||
|
||||
# manually change 0.5 --> \frac{1}{2}
|
||||
if string == "0.5":
|
||||
string = "\\frac{1}{2}"
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = _fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def _remove_right_units(string: str) -> str:
|
||||
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
||||
try:
|
||||
if "\\text{ " in string:
|
||||
splits = string.split("\\text{ ")
|
||||
assert len(splits) == 2
|
||||
return splits[0]
|
||||
else:
|
||||
return string
|
||||
except AssertionError:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string: str) -> str:
|
||||
if "\\sqrt" not in string:
|
||||
return string
|
||||
splits = string.split("\\sqrt")
|
||||
new_string = splits[0]
|
||||
for split in splits[1:]:
|
||||
if len(split) == 0:
|
||||
return string
|
||||
if split[0] != "{":
|
||||
a = split[0]
|
||||
new_substr = "\\sqrt{" + a + "}" + split[1:]
|
||||
else:
|
||||
new_substr = "\\sqrt" + split
|
||||
new_string += new_substr
|
||||
return new_string
|
||||
|
||||
|
||||
def _fix_fracs(string: str) -> str:
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if len(substr) == 0:
|
||||
return string
|
||||
if substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except AssertionError:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string: str) -> str:
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
ia = int(a)
|
||||
ib = int(b)
|
||||
assert string == "{}/{}".format(ia, ib)
|
||||
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
|
||||
return new_string
|
||||
except (ValueError, AssertionError):
|
||||
return string
|
||||
|
|
@ -3,11 +3,11 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import BraintrustScoringConfig
|
||||
|
||||
|
|
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
|
|||
|
||||
async def get_provider_impl(
|
||||
config: BraintrustScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .braintrust import BraintrustScoringImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -167,11 +167,11 @@ class BraintrustScoringImpl(
|
|||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
limit=-1,
|
||||
)
|
||||
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
|
||||
res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions)
|
||||
if save_results_dataset:
|
||||
# TODO: persist and register dataset on to server for reading
|
||||
# self.datasets_api.register_dataset()
|
||||
|
|
|
|||
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import LlmAsJudgeScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LlmAsJudgeScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .scoring import LlmAsJudgeScoringImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlmAsJudgeScoringConfig(BaseModel): ...
|
||||
class LlmAsJudgeScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -72,12 +72,12 @@ class LlmAsJudgeScoringImpl(
|
|||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
all_rows = await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
limit=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.rows,
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
if save_results_dataset:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
|
|
@ -34,7 +34,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
|
||||
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
print(
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
|
|
@ -46,7 +46,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
span_context = (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
|
|
@ -74,7 +74,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
|
||||
|
||||
for event in span.events:
|
||||
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime("%H:%M:%S.%f")[:-3]
|
||||
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import json
|
|||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
from opentelemetry.trace import Span
|
||||
|
|
@ -124,8 +124,8 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
trace_id,
|
||||
service_name,
|
||||
(span_id if not parent_span_id else None),
|
||||
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -143,8 +143,8 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
trace_id,
|
||||
parent_span_id,
|
||||
span.name,
|
||||
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
|
||||
json.dumps(dict(span.attributes)),
|
||||
span.status.status_code.name,
|
||||
span.kind.name,
|
||||
|
|
@ -161,7 +161,7 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
(
|
||||
span_id,
|
||||
event.name,
|
||||
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(event.timestamp / 1e9, timezone.utc).isoformat(),
|
||||
json.dumps(dict(event.attributes)),
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = deps.get(Api.datasetio)
|
||||
self.meter = None
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
|
|
@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
if self.meter is None:
|
||||
return
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
|
|
|
|||
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleTelemetryImpl(Telemetry):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeInterpreterToolConfig
|
||||
|
||||
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
|
||||
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||
|
||||
impl = CodeInterpreterToolRuntimeImpl(config)
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ class CodeExecutionRequest:
|
|||
only_last_cell_fail: bool = True
|
||||
seed: int = 0
|
||||
strip_fpaths_in_stderr: bool = True
|
||||
use_bwrap: bool = True
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
|
|
@ -103,8 +104,6 @@ _set_seeds()\
|
|||
|
||||
script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
|
||||
with tempfile.TemporaryDirectory() as dpath:
|
||||
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
|
||||
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
|
||||
code_fpath = os.path.join(dpath, "code.py")
|
||||
with open(code_fpath, "w") as f:
|
||||
f.write(script)
|
||||
|
|
@ -118,6 +117,13 @@ _set_seeds()\
|
|||
MPLBACKEND="module://matplotlib_custom_backend",
|
||||
PYTHONPATH=f"{DIRNAME}:{python_path}",
|
||||
)
|
||||
|
||||
if req.use_bwrap:
|
||||
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
|
||||
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
|
||||
else:
|
||||
cmd = [sys.executable, "-c", script]
|
||||
|
||||
stdout, stderr, returncode = do_subprocess(
|
||||
cmd=cmd,
|
||||
env=env,
|
||||
|
|
@ -162,7 +168,7 @@ def process_matplotlib_response(response, matplotlib_dump_dir: str):
|
|||
image_paths = []
|
||||
for i, img in enumerate(images):
|
||||
# create new directory for each day to better organize data:
|
||||
dump_dname = datetime.today().strftime("%Y-%m-%d")
|
||||
dump_dname = datetime.today().strftime("%Y-%m-%d") # noqa: DTZ002 - we don't care about timezones here since we are displaying the date
|
||||
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
|
||||
dump_dpath.mkdir(parents=True, exist_ok=True)
|
||||
# save image into a file
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
|
@ -36,7 +38,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
|
@ -61,8 +63,10 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
script = kwargs["code"]
|
||||
req = CodeExecutionRequest(scripts=[script])
|
||||
res = self.code_executor.execute(req)
|
||||
# Use environment variable to control bwrap usage
|
||||
force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes")
|
||||
req = CodeExecutionRequest(scripts=[script], use_bwrap=not force_disable_bwrap)
|
||||
res = await asyncio.to_thread(self.code_executor.execute, req)
|
||||
pieces = [res["process_status"]]
|
||||
for out_type in ["stdout", "stderr"]:
|
||||
res_out = res[out_type]
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterToolConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RagToolRuntimeConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import ChromaVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
||||
ChromaVectorIOAdapter,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
|
|||
db_path: str
|
||||
|
||||
@classmethod
|
||||
def sample_config(cls) -> Dict[str, Any]:
|
||||
return {"db_path": "{env.CHROMADB_PATH}"}
|
||||
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> Dict[str, Any]:
|
||||
return {"db_path": db_path}
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from .faiss import FaissVectorIOAdapter
|
||||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
||||
|
|
|
|||
19
llama_stack/providers/inline/vector_io/qdrant/__init__.py
Normal file
19
llama_stack/providers/inline/vector_io/qdrant/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import QdrantVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
23
llama_stack/providers/inline/vector_io/qdrant/config.py
Normal file
23
llama_stack/providers/inline/vector_io/qdrant/config.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QdrantVectorIOConfig(BaseModel):
|
||||
path: str
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
|
||||
}
|
||||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import SQLiteVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
|||
|
|
@ -7,11 +7,9 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_dependencies
|
||||
|
||||
|
|
@ -39,13 +37,4 @@ def available_providers() -> List[ProviderSpec]:
|
|||
Api.tool_groups,
|
||||
],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.agents,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.agents.sample",
|
||||
config_class="llama_stack.providers.remote.agents.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.eval,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=[],
|
||||
pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
|
||||
module="llama_stack.providers.inline.eval.meta_reference",
|
||||
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
|||
|
|
@ -68,15 +68,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.inference.sentence_transformers",
|
||||
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.sample",
|
||||
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
|||
|
|
@ -27,27 +27,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.safety.prompt_guard",
|
||||
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=[
|
||||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
module="llama_stack.providers.inline.safety.meta_reference",
|
||||
config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
],
|
||||
deprecation_error="""
|
||||
Provider `inline::meta-reference` for API `safety` does not work with the latest Llama Stack.
|
||||
|
||||
- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead.
|
||||
- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead.
|
||||
- if you are using Code Scanner, please use the `inline::code-scanner` provider instead.
|
||||
|
||||
""",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="inline::llama-guard",
|
||||
|
|
@ -67,15 +46,6 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest
|
|||
module="llama_stack.providers.inline.safety.code_scanner",
|
||||
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.safety.sample",
|
||||
config_class="llama_stack.providers.remote.safety.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
|
|
@ -85,4 +55,13 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest
|
|||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=["requests"],
|
||||
module="llama_stack.providers.remote.safety.nvidia",
|
||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -7,11 +7,9 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -28,13 +26,4 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.telemetry.meta_reference",
|
||||
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.telemetry,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.telemetry.sample",
|
||||
config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
# NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
|
||||
# source distribution and the wheels are not available for all platforms.
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::sqlite-vec",
|
||||
|
|
@ -90,15 +92,13 @@ def available_providers() -> List[ProviderSpec]:
|
|||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.vector_io.sample",
|
||||
config_class="llama_stack.providers.remote.vector_io.sample.SampleVectorIOConfig",
|
||||
),
|
||||
api_dependencies=[],
|
||||
provider_type="inline::qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
module="llama_stack.providers.inline.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleAgentsImpl
|
||||
|
||||
impl = SampleAgentsImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleAgentsImpl(Agents):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class HuggingfaceDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "huggingface_datasetio.db").as_posix()
|
||||
) # Uses SQLite config specific to HF storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="huggingface_datasetio.db",
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import datasets as hf_datasets
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .config import HuggingfaceDatasetIOConfig
|
||||
|
|
@ -18,22 +18,14 @@ from .config import HuggingfaceDatasetIOConfig
|
|||
DATASETS_PREFIX = "datasets:"
|
||||
|
||||
|
||||
def load_hf_dataset(dataset_def: Dataset):
|
||||
if dataset_def.metadata.get("path", None):
|
||||
dataset = hf_datasets.load_dataset(**dataset_def.metadata)
|
||||
else:
|
||||
df = get_dataframe_from_url(dataset_def.url)
|
||||
def parse_hf_params(dataset_def: Dataset):
|
||||
uri = dataset_def.source.uri
|
||||
parsed_uri = urlparse(uri)
|
||||
params = parse_qs(parsed_uri.query)
|
||||
params = {k: v[0] for k, v in params.items()}
|
||||
path = parsed_uri.path.lstrip("/")
|
||||
|
||||
if df is None:
|
||||
raise ValueError(f"Failed to load dataset from {dataset_def.url}")
|
||||
|
||||
dataset = hf_datasets.Dataset.from_pandas(df)
|
||||
|
||||
# drop columns not specified by schema
|
||||
if dataset_def.dataset_schema:
|
||||
dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
|
||||
|
||||
return dataset
|
||||
return path, params
|
||||
|
||||
|
||||
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
|
|
@ -64,7 +56,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=dataset_def.json(),
|
||||
value=dataset_def.model_dump_json(),
|
||||
)
|
||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||
|
||||
|
|
@ -73,41 +65,34 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
loaded_dataset = load_hf_dataset(dataset_def)
|
||||
path, params = parse_hf_params(dataset_def)
|
||||
loaded_dataset = hf_datasets.load_dataset(path, **params)
|
||||
|
||||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
start_index = start_index or 0
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
else:
|
||||
next_page_token = int(page_token)
|
||||
|
||||
start = next_page_token
|
||||
if rows_in_page == -1:
|
||||
if limit is None or limit == -1:
|
||||
end = len(loaded_dataset)
|
||||
else:
|
||||
end = min(start + rows_in_page, len(loaded_dataset))
|
||||
end = min(start_index + limit, len(loaded_dataset))
|
||||
|
||||
rows = [loaded_dataset[i] for i in range(start, end)]
|
||||
rows = [loaded_dataset[i] for i in range(start_index, end)]
|
||||
|
||||
return PaginatedRowsResult(
|
||||
rows=rows,
|
||||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
return IterrowsResponse(
|
||||
data=rows,
|
||||
next_start_index=end if end < len(loaded_dataset) else None,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
loaded_dataset = load_hf_dataset(dataset_def)
|
||||
path, params = parse_hf_params(dataset_def)
|
||||
loaded_dataset = hf_datasets.load_dataset(path, **params)
|
||||
|
||||
# Convert rows to HF Dataset format
|
||||
new_dataset = hf_datasets.Dataset.from_list(rows)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -20,3 +21,15 @@ class DatabricksImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN}",
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,10 +24,6 @@ MODEL_ENTRIES = [
|
|||
"accounts/fireworks/models/llama-v3p1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import logging
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
|
||||
|
|
@ -82,12 +83,42 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
# )
|
||||
|
||||
self._config = config
|
||||
# make sure the client lives longer than any async calls
|
||||
self._client = AsyncOpenAI(
|
||||
base_url=f"{self._config.url}/v1",
|
||||
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
@lru_cache # noqa: B019
|
||||
def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
|
||||
"""
|
||||
For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
|
||||
some models are hosted on different URLs. This function returns the appropriate client
|
||||
for the given provider_model_id.
|
||||
|
||||
This relies on lru_cache and self._default_client to avoid creating a new client for each request
|
||||
or for each model that is hosted on https://integrate.api.nvidia.com/v1.
|
||||
|
||||
:param provider_model_id: The provider model ID
|
||||
:return: An OpenAI client
|
||||
"""
|
||||
|
||||
@lru_cache # noqa: B019
|
||||
def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
|
||||
"""
|
||||
Maintain a single OpenAI client per base_url.
|
||||
"""
|
||||
return AsyncOpenAI(
|
||||
base_url=base_url,
|
||||
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
special_model_urls = {
|
||||
"meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
|
||||
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
|
||||
}
|
||||
|
||||
base_url = f"{self._config.url}/v1"
|
||||
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
|
||||
base_url = special_model_urls[provider_model_id]
|
||||
|
||||
return _get_client_for_base_url(base_url)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
@ -105,9 +136,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
|
||||
await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = self.get_provider_model_id(model_id)
|
||||
request = convert_completion_request(
|
||||
request=CompletionRequest(
|
||||
model=self.get_provider_model_id(model_id),
|
||||
model=provider_model_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
|
|
@ -118,7 +150,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._client.completions.create(**request)
|
||||
response = await self._get_client(provider_model_id).completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -206,6 +238,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
|
||||
await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = self.get_provider_model_id(model_id)
|
||||
request = await convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=self.get_provider_model_id(model_id),
|
||||
|
|
@ -221,7 +254,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._client.chat.completions.create(**request)
|
||||
response = await self._get_client(provider_model_id).chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
|
|
@ -24,6 +27,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .config import PassthroughImplConfig
|
||||
|
|
@ -46,7 +50,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
def _get_client(self) -> LlamaStackClient:
|
||||
def _get_client(self) -> AsyncLlamaStackClient:
|
||||
passthrough_url = None
|
||||
passthrough_api_key = None
|
||||
provider_data = None
|
||||
|
|
@ -71,7 +75,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
)
|
||||
passthrough_api_key = provider_data.passthrough_api_key
|
||||
|
||||
return LlamaStackClient(
|
||||
return AsyncLlamaStackClient(
|
||||
base_url=passthrough_url,
|
||||
api_key=passthrough_api_key,
|
||||
provider_data=provider_data,
|
||||
|
|
@ -91,7 +95,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
params = {
|
||||
request_params = {
|
||||
"model_id": model.provider_resource_id,
|
||||
"content": content,
|
||||
"sampling_params": sampling_params,
|
||||
|
|
@ -100,10 +104,13 @@ class PassthroughInferenceAdapter(Inference):
|
|||
"logprobs": logprobs,
|
||||
}
|
||||
|
||||
params = {key: value for key, value in params.items() if value is not None}
|
||||
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||
|
||||
# cast everything to json dict
|
||||
json_params = self.cast_value_to_json_dict(request_params)
|
||||
|
||||
# only pass through the not None params
|
||||
return client.inference.completion(**params)
|
||||
return await client.inference.completion(**json_params)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
|
|
@ -120,10 +127,14 @@ class PassthroughInferenceAdapter(Inference):
|
|||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
params = {
|
||||
# TODO: revisit this remove tool_calls from messages logic
|
||||
for message in messages:
|
||||
if hasattr(message, "tool_calls"):
|
||||
message.tool_calls = None
|
||||
|
||||
request_params = {
|
||||
"model_id": model.provider_resource_id,
|
||||
"messages": messages,
|
||||
"sampling_params": sampling_params,
|
||||
|
|
@ -135,10 +146,41 @@ class PassthroughInferenceAdapter(Inference):
|
|||
"logprobs": logprobs,
|
||||
}
|
||||
|
||||
params = {key: value for key, value in params.items() if value is not None}
|
||||
|
||||
# only pass through the not None params
|
||||
return client.inference.chat_completion(**params)
|
||||
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||
|
||||
# cast everything to json dict
|
||||
json_params = self.cast_value_to_json_dict(request_params)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(json_params)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(json_params)
|
||||
|
||||
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
client = self._get_client()
|
||||
response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=response.completion_message.content.text,
|
||||
stop_reason=response.completion_message.stop_reason,
|
||||
tool_calls=response.completion_message.tool_calls,
|
||||
),
|
||||
logprobs=response.logprobs,
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
|
||||
client = self._get_client()
|
||||
stream_response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
async for chunk in stream_response:
|
||||
chunk = chunk.to_dict()
|
||||
|
||||
# temporary hack to remove the metrics from the response
|
||||
chunk["metrics"] = []
|
||||
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
@ -151,10 +193,29 @@ class PassthroughInferenceAdapter(Inference):
|
|||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
return client.inference.embeddings(
|
||||
return await client.inference.embeddings(
|
||||
model_id=model.provider_resource_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
output_dimension=output_dimension,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
json_params = {}
|
||||
for key, value in request_params.items():
|
||||
json_input = convert_pydantic_to_json_value(value)
|
||||
if isinstance(json_input, dict):
|
||||
json_input = {k: v for k, v in json_input.items() if v is not None}
|
||||
elif isinstance(json_input, list):
|
||||
json_input = [x for x in json_input if x is not None]
|
||||
new_input = []
|
||||
for x in json_input:
|
||||
if isinstance(x, dict):
|
||||
x = {k: v for k, v in x.items() if v is not None}
|
||||
new_input.append(x)
|
||||
json_input = new_input
|
||||
|
||||
json_params[key] = json_input
|
||||
|
||||
return json_params
|
||||
|
|
|
|||
|
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import RunpodImplConfig
|
||||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RunpodImplConfig, _deps):
|
||||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = RunpodInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -21,3 +21,10 @@ class RunpodImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.RUNPOD_URL:}",
|
||||
"api_token": "${env.RUNPOD_API_TOKEN:}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from typing import AsyncGenerator
|
|||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.models.llama.datatypes import Message
|
||||
|
||||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
|
|
|||
|
|
@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
|
|||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
|
|
@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
if not tool_calls:
|
||||
return []
|
||||
|
||||
for call in tool_calls:
|
||||
call_function_arguments = json.loads(call.function.arguments)
|
||||
|
||||
compitable_tool_calls = [
|
||||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=call_function_arguments,
|
||||
arguments=json.loads(call.function.arguments),
|
||||
arguments_json=call.function.arguments,
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleInferenceImpl
|
||||
|
||||
impl = SampleInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleInferenceImpl(Inference):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
|
|||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY}",
|
||||
"api_key": "${env.TOGETHER_API_KEY:}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from together import Together
|
||||
from together import AsyncTogether
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
|
@ -59,12 +59,15 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
@ -91,35 +94,32 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_client(self) -> Together:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
return Together(api_key=together_api_key)
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
if not self._client:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
self._client = AsyncTogether(api_key=together_api_key)
|
||||
return self._client
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client().completions.create(**params)
|
||||
client = self._get_client()
|
||||
r = await client.completions.create(**params)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
client = await self._get_client()
|
||||
stream = await client.completions.create(**params)
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
|
|
@ -184,25 +184,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
client = self._get_client()
|
||||
if "messages" in params:
|
||||
r = self._get_client().chat.completions.create(**params)
|
||||
r = await client.chat.completions.create(**params)
|
||||
else:
|
||||
r = self._get_client().completions.create(**params)
|
||||
r = await client.completions.create(**params)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
client = self._get_client()
|
||||
if "messages" in params:
|
||||
stream = await client.chat.completions.create(**params)
|
||||
else:
|
||||
stream = await client.completions.create(**params)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
if "messages" in params:
|
||||
s = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
|
|
@ -240,7 +236,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Together does not support media for embeddings"
|
||||
)
|
||||
r = self._get_client().embeddings.create(
|
||||
client = self._get_client()
|
||||
r = await client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,10 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
default="fake",
|
||||
description="The API token",
|
||||
)
|
||||
tls_verify: bool = Field(
|
||||
default=True,
|
||||
description="Whether to verify TLS certificates",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
|
|
@ -36,4 +40,5 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
"url": url,
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
|
||||
"api_token": "${env.VLLM_API_TOKEN:fake}",
|
||||
"tls_verify": "${env.VLLM_TLS_VERIFY:true}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import json
|
|||
import logging
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
|
|
@ -89,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response(
|
|||
if not tool_calls:
|
||||
return []
|
||||
|
||||
call_function_arguments = None
|
||||
for call in tool_calls:
|
||||
call_function_arguments = json.loads(call.function.arguments)
|
||||
|
||||
return [
|
||||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=call_function_arguments,
|
||||
arguments=json.loads(call.function.arguments),
|
||||
arguments_json=call.function.arguments,
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
|
@ -182,6 +180,7 @@ async def _process_vllm_chat_completion_stream_response(
|
|||
call_id=tool_call_buf.call_id,
|
||||
tool_name=tool_call_buf.tool_name,
|
||||
arguments=args,
|
||||
arguments_json=args_str,
|
||||
),
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
|
|
@ -229,7 +228,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
||||
self.client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
self.client = AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -4,14 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleTelemetryImpl
|
||||
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
|
||||
from .nvidia import NVIDIASafetyAdapter
|
||||
|
||||
impl = SampleTelemetryImpl(config)
|
||||
impl = NVIDIASafetyAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
37
llama_stack/providers/remote/safety/nvidia/config.py
Normal file
37
llama_stack/providers/remote/safety/nvidia/config.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NVIDIASafetyConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the NVIDIA Guardrail microservice endpoint.
|
||||
|
||||
Attributes:
|
||||
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
|
||||
config_id (str): The ID of the guardrails configuration to use from the configuration store
|
||||
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
|
||||
|
||||
"""
|
||||
|
||||
guardrails_service_url: str = Field(
|
||||
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
|
||||
description="The url for accessing the guardrails service",
|
||||
)
|
||||
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
|
||||
"config_id": "self-check",
|
||||
}
|
||||
154
llama_stack/providers/remote/safety/nvidia/nvidia.py
Normal file
154
llama_stack/providers/remote/safety/nvidia/nvidia.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
||||
"""
|
||||
Initialize the NVIDIASafetyAdapter with a given safety configuration.
|
||||
|
||||
Args:
|
||||
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
|
||||
"""
|
||||
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if not shield.provider_resource_id:
|
||||
raise ValueError("Shield model not provided.")
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: List[Message], params: Optional[dict[str, Any]] = None
|
||||
) -> RunShieldResponse:
|
||||
"""
|
||||
Run a safety shield check against the provided messages.
|
||||
|
||||
Args:
|
||||
shield_id (str): The unique identifier for the shield to be used.
|
||||
messages (List[Message]): A list of Message objects representing the conversation history.
|
||||
params (Optional[dict[str, Any]]): Additional parameters for the shield check.
|
||||
|
||||
Returns:
|
||||
RunShieldResponse: The response containing safety violation details if any.
|
||||
|
||||
Raises:
|
||||
ValueError: If the shield with the provided shield_id is not found.
|
||||
"""
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
self.shield = NeMoGuardrails(self.config, shield.shield_id)
|
||||
return await self.shield.run(messages)
|
||||
|
||||
|
||||
class NeMoGuardrails:
|
||||
"""
|
||||
A class that encapsulates NVIDIA's guardrails safety logic.
|
||||
|
||||
Sends messages to the guardrails service and interprets the response to determine
|
||||
if a safety violation has occurred.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: NVIDIASafetyConfig,
|
||||
model: str,
|
||||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Initialize a NeMoGuardrails instance with the provided parameters.
|
||||
|
||||
Args:
|
||||
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
|
||||
model (str): The identifier or name of the model to be used for safety checks.
|
||||
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
|
||||
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
|
||||
|
||||
Raises:
|
||||
ValueError: If temperature is less than or equal to 0.
|
||||
AssertionError: If config_id is not provided in the configuration.
|
||||
"""
|
||||
self.config_id = config.config_id
|
||||
self.model = model
|
||||
assert self.config_id is not None, "Must provide config id"
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be greater than 0")
|
||||
|
||||
self.temperature = temperature
|
||||
self.threshold = threshold
|
||||
self.guardrails_service_url = config.guardrails_service_url
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
"""
|
||||
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
|
||||
|
||||
Args:
|
||||
messages (List[Message]): A list of Message objects to be checked for safety violations.
|
||||
|
||||
Returns:
|
||||
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
|
||||
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
|
||||
|
||||
Raises:
|
||||
requests.HTTPError: If the POST request fails.
|
||||
"""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
}
|
||||
request_data = {
|
||||
"model": self.model,
|
||||
"messages": convert_pydantic_to_json_value(messages),
|
||||
"temperature": self.temperature,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": self.config_id,
|
||||
},
|
||||
}
|
||||
response = requests.post(
|
||||
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
|
||||
)
|
||||
response.raise_for_status()
|
||||
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
|
||||
response_json = response.json()
|
||||
if response_json["status"] == "blocked":
|
||||
user_message = "Sorry I cannot do this."
|
||||
metadata = response_json["rails_status"]
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
return RunShieldResponse(violation=None)
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleSafetyImpl
|
||||
|
||||
impl = SampleSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue