Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 18:22:41 +00:00)

Commit f90e9c2003: agents to use tools api
Parent: 596afc6497
21 changed files with 538 additions and 329 deletions
@@ -14,18 +14,16 @@ from typing import (
     Literal,
     Optional,
     Protocol,
-    runtime_checkable,
     Union,
+    runtime_checkable,
 )
 
 from llama_models.llama3.api.datatypes import ToolParamDefinition
-
 from llama_models.schema_utils import json_schema_type, webmethod
-
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
 from llama_stack.apis.inference import (
     CompletionMessage,
@@ -40,7 +38,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.memory import MemoryBank
 from llama_stack.apis.safety import SafetyViolation
-
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
 
@@ -110,85 +107,6 @@ class FunctionCallToolDefinition(ToolDefinitionCommon):
     remote_execution: Optional[RestAPIExecutionConfig] = None
 
 
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-
-class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-
-
-class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-
-
-class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-
-
-class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        AgentVectorMemoryBankConfig,
-        AgentKeyValueMemoryBankConfig,
-        AgentKeywordMemoryBankConfig,
-        AgentGraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-@json_schema_type
-class MemoryToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.memory.value] = AgentTool.memory.value
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 10
-
-
 AgentToolDefinition = Annotated[
     Union[
         SearchToolDefinition,
@@ -196,7 +114,6 @@ AgentToolDefinition = Annotated[
         PhotogenToolDefinition,
         CodeInterpreterToolDefinition,
         FunctionCallToolDefinition,
-        MemoryToolDefinition,
     ],
     Field(discriminator="type"),
 ]
@@ -295,7 +212,11 @@ class AgentConfigCommon(BaseModel):
     input_shields: Optional[List[str]] = Field(default_factory=list)
     output_shields: Optional[List[str]] = Field(default_factory=list)
 
-    tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
+    tools: Optional[List[AgentToolDefinition]] = Field(
+        default_factory=list, deprecated=True
+    )
+    available_tools: Optional[List[str]] = Field(default_factory=list)
+    preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
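With this change an agent names its tools by string instead of embedding full tool definitions. A minimal sketch of the new configuration shape, grounded in the test_agents.py change later in this commit (model and instructions values are illustrative):

    # Sketch only: field values are illustrative, not from the commit.
    agent_config = AgentConfig(
        model="Llama3.1-8B-Instruct",
        instructions="You are a helpful assistant",
        enable_session_persistence=False,
        tools=[],                          # deprecated as of this commit
        available_tools=["brave_search"],  # resolved through the tools API
        preprocessing_tools=["memory"],    # run before inference each turn
    )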
@@ -68,10 +68,16 @@ ToolGroupDef = register_schema(
     Annotated[
         Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
     ],
-    name="ToolGroup",
+    name="ToolGroupDef",
 )
 
 
+class ToolGroupInput(BaseModel):
+    tool_group_id: str
+    tool_group: ToolGroupDef
+    provider_id: Optional[str] = None
+
+
 class ToolGroup(Resource):
     type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
 
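ToolGroupInput is what a run config (and the test fixtures further down in this diff) uses to pre-register a tool group at stack startup. A sketch mirroring the fixture added later in this commit, with the parameter list trimmed to one entry:

    # Mirrors the test fixture later in this commit (trimmed); provider_id
    # assumes the inline memory runtime is registered as "memory-runtime".
    tool_group = ToolGroupInput(
        tool_group_id="memory_group",
        tool_group=UserDefinedToolGroupDef(
            tools=[
                ToolDef(
                    name="memory",
                    description="memory",
                    parameters=[
                        ToolParameter(
                            name="session_id",
                            description="session id",
                            parameter_type="string",
                            required=True,
                        ),
                    ],
                    metadata={},
                )
            ],
        ),
        provider_id="memory-runtime",
    )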
@ -161,6 +161,7 @@ a default SQLite store will be used.""",
|
||||||
datasets: List[DatasetInput] = Field(default_factory=list)
|
datasets: List[DatasetInput] = Field(default_factory=list)
|
||||||
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
|
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
|
||||||
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
|
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
|
||||||
|
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class BuildConfig(BaseModel):
|
class BuildConfig(BaseModel):
|
||||||
|
|
|
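In a run config this surfaces as one more resource list next to the existing ones; a minimal sketch with the other StackRunConfig fields elided:

    # Sketch: tool_groups joins the existing resource-input lists.
    run_config_fragment = dict(
        models=[],
        shields=[],
        datasets=[],
        scoring_fns=[],
        eval_tasks=[],
        tool_groups=[tool_group],  # the ToolGroupInput from the sketch above
    )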
@@ -5,9 +5,7 @@
 # the root directory of this source tree.
 import importlib
 import inspect
-
 import logging
-
 from typing import Any, Dict, List, Set
 
 from llama_stack.apis.agents import Agents
@@ -28,7 +26,6 @@ from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.distribution.client import get_client_impl
-
 from llama_stack.distribution.datatypes import (
     AutoRoutedProviderSpec,
     Provider,
@@ -38,7 +35,7 @@ from llama_stack.distribution.datatypes import (
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
-from llama_stack.providers.datatypes import *  # noqa: F403
 from llama_stack.providers.datatypes import (
     Api,
     DatasetsProtocolPrivate,
@@ -523,6 +523,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         )
         provider_id = list(self.impls_by_provider_id.keys())[0]
 
+        # parse tool group to the type if dict
+        tool_group = parse_obj_as(ToolGroupDef, tool_group)
         if isinstance(tool_group, MCPToolGroupDef):
            tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
                tool_group
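The isinstance checks below only work on typed models, so a tool group arriving as a plain dict (for example, from a YAML run config) is first coerced into the ToolGroupDef discriminated union. A standalone sketch of the same pydantic parse_obj_as pattern, using a hypothetical two-member union in place of ToolGroupDef:

    from typing import Annotated, Literal, Union

    from pydantic import BaseModel, Field, parse_obj_as


    # Hypothetical union members, standing in for MCPToolGroupDef and
    # UserDefinedToolGroupDef.
    class MCPGroup(BaseModel):
        type: Literal["model_context_protocol"] = "model_context_protocol"
        endpoint: str


    class UserDefinedGroup(BaseModel):
        type: Literal["user_defined"] = "user_defined"


    Group = Annotated[Union[MCPGroup, UserDefinedGroup], Field(discriminator="type")]

    # A raw dict is coerced to the right union member via the "type"
    # discriminator, after which isinstance() dispatch works as above.
    group = parse_obj_as(
        Group, {"type": "model_context_protocol", "endpoint": "http://localhost:1234"}
    )
    assert isinstance(group, MCPGroup)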
@@ -12,7 +12,7 @@ from typing import Any, Dict, Optional
 
 import pkg_resources
 import yaml
-from llama_models.llama3.api.datatypes import *  # noqa: F403
 from termcolor import colored
 
 from llama_stack.apis.agents import Agents
@@ -33,14 +33,12 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
-
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.providers.datatypes import Api
 
-
 log = logging.getLogger(__name__)
 
 LLAMA_STACK_API_VERSION = "alpha"
@@ -81,6 +79,7 @@ RESOURCES = [
         "list_scoring_functions",
     ),
     ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
+    ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
 ]
 
 
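Each RESOURCES entry names a run-config attribute plus the register/list method names on the corresponding API, so adding the tool_groups tuple is enough for startup registration to cover every ToolGroupInput. A hedged sketch of the kind of loop that consumes these tuples (the actual implementation in stack.py may differ in detail):

    # Hypothetical consumer loop for the RESOURCES table above.
    async def register_resources(run_config, impls):
        for rsrc, api, register_method, list_method in RESOURCES:
            objects = getattr(run_config, rsrc)  # e.g. run_config.tool_groups
            if api not in impls:
                continue
            method = getattr(impls[api], register_method)
            for obj in objects:
                # each object is a pydantic input model, e.g. ToolGroupInput
                await method(**obj.model_dump())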
(new file, 5 lines)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -22,6 +22,8 @@ async def get_provider_impl(
         deps[Api.memory],
         deps[Api.safety],
         deps[Api.memory_banks],
+        deps[Api.tool_runtime],
+        deps[Api.tool_groups],
     )
     await impl.initialize()
     return impl
@@ -4,25 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import asyncio
 import copy
 import logging
 import os
-import re
 import secrets
 import string
 import uuid
 from datetime import datetime
-from typing import AsyncGenerator, Dict, List, Optional, Tuple
+from typing import AsyncGenerator, Dict, List
 from urllib.parse import urlparse
 
 import httpx
 
 from llama_models.llama3.api.datatypes import BuiltinTool
 
 from llama_stack.apis.agents import (
     AgentConfig,
-    AgentTool,
     AgentTurnCreateRequest,
     AgentTurnResponseEvent,
     AgentTurnResponseEventType,
@@ -36,8 +32,6 @@ from llama_stack.apis.agents import (
     CodeInterpreterToolDefinition,
     FunctionCallToolDefinition,
     InferenceStep,
-    MemoryRetrievalStep,
-    MemoryToolDefinition,
     PhotogenToolDefinition,
     SearchToolDefinition,
     ShieldCallStep,
@@ -46,11 +40,9 @@ from llama_stack.apis.agents import (
     Turn,
     WolframAlphaToolDefinition,
 )
 
 from llama_stack.apis.common.content_types import (
-    InterleavedContent,
-    TextContentItem,
     URL,
+    TextContentItem,
 )
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
@@ -62,30 +54,26 @@ from llama_stack.apis.inference import (
     SystemMessage,
     ToolCallDelta,
     ToolCallParseStatus,
-    ToolChoice,
     ToolDefinition,
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
 
 from llama_stack.providers.utils.kvstore import KVStore
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
-from .rag.context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
 from .tools.base import BaseTool
 from .tools.builtin import (
     CodeInterpreterTool,
-    interpret_content_as_attachment,
     PhotogenTool,
     SearchTool,
     WolframAlphaTool,
+    interpret_content_as_attachment,
 )
 from .tools.safety import SafeTool
 
@@ -108,6 +96,8 @@ class ChatAgent(ShieldRunnerMixin):
         memory_api: Memory,
         memory_banks_api: MemoryBanks,
         safety_api: Safety,
+        tool_runtime_api: ToolRuntime,
+        tool_groups_api: ToolGroups,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id
|
||||||
self.memory_banks_api = memory_banks_api
|
self.memory_banks_api = memory_banks_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||||
|
self.tool_runtime_api = tool_runtime_api
|
||||||
|
self.tool_groups_api = tool_groups_api
|
||||||
|
|
||||||
builtin_tools = []
|
builtin_tools = []
|
||||||
for tool_defn in agent_config.tools:
|
for tool_defn in agent_config.tools:
|
||||||
|
@@ -392,62 +384,50 @@ class ChatAgent(ShieldRunnerMixin):
         sampling_params: SamplingParams,
         stream: bool = False,
     ) -> AsyncGenerator:
-        enabled_tools = set(t.type for t in self.agent_config.tools)
-        need_rag_context = await self._should_retrieve_context(
-            input_messages, attachments
-        )
-        if need_rag_context:
-            step_id = str(uuid.uuid4())
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.memory_retrieval.value,
-                        step_id=step_id,
-                    )
-                )
-            )
-
-            # TODO: find older context from the session and either replace it
-            # or append with a sliding window. this is really a very simplistic implementation
-            with tracing.span("retrieve_rag_context") as span:
-                rag_context, bank_ids = await self._retrieve_context(
-                    session_id, input_messages, attachments
-                )
-                span.set_attribute(
-                    "input", [m.model_dump_json() for m in input_messages]
-                )
-                span.set_attribute("output", rag_context)
-                span.set_attribute("bank_ids", bank_ids)
-
-            step_id = str(uuid.uuid4())
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.memory_retrieval.value,
-                        step_id=step_id,
-                        step_details=MemoryRetrievalStep(
-                            turn_id=turn_id,
-                            step_id=step_id,
-                            memory_bank_ids=bank_ids,
-                            inserted_context=rag_context or "",
-                        ),
-                    )
-                )
-            )
-
-            if rag_context:
-                last_message = input_messages[-1]
-                last_message.context = rag_context
-
-        elif attachments and AgentTool.code_interpreter.value in enabled_tools:
-            urls = [a.content for a in attachments if isinstance(a.content, URL)]
-            # TODO: we need to migrate URL away from str type
-            pattern = re.compile("^(https?://|file://|data:)")
-            urls += [
-                URL(uri=a.content) for a in attachments if pattern.match(a.content)
-            ]
-            msg = await attachment_message(self.tempdir, urls)
-            input_messages.append(msg)
+        if self.agent_config.preprocessing_tools:
+            with tracing.span("preprocessing_tools") as span:
+                for tool_name in self.agent_config.preprocessing_tools:
+                    yield AgentTurnResponseStreamChunk(
+                        event=AgentTurnResponseEvent(
+                            payload=AgentTurnResponseStepStartPayload(
+                                step_type=StepType.tool_execution.value,
+                                step_id=str(uuid.uuid4()),
+                            )
+                        )
+                    )
+                    args = dict(
+                        session_id=session_id,
+                        input_messages=input_messages,
+                        attachments=attachments,
+                    )
+                    result = await self.tool_runtime_api.invoke_tool(
+                        tool_name=tool_name,
+                        args=args,
+                    )
+                    yield AgentTurnResponseStreamChunk(
+                        event=AgentTurnResponseEvent(
+                            payload=AgentTurnResponseStepProgressPayload(
+                                step_type=StepType.tool_execution.value,
+                                step_id=str(uuid.uuid4()),
+                                tool_call_delta=ToolCallDelta(
+                                    parse_status=ToolCallParseStatus.success,
+                                    content=ToolCall(
+                                        call_id="", tool_name=tool_name, arguments={}
+                                    ),
+                                ),
+                            )
+                        )
+                    )
+                    span.set_attribute(
+                        "input", [m.model_dump_json() for m in input_messages]
+                    )
+                    span.set_attribute("output", result.content)
+                    span.set_attribute("error_code", result.error_code)
+                    span.set_attribute("error_message", result.error_message)
+                    span.set_attribute("tool_name", tool_name)
+                    if result.error_code != 0 and result.content:
+                        last_message = input_messages[-1]
+                        last_message.context = result.content
 
         output_attachments = []
 
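Preprocessing tools are driven through a uniform runtime contract: a tool name plus a plain args dict in, a ToolInvocationResult out. A hedged sketch of a standalone caller, assuming only the names visible in this hunk (invoke_tool, content, error_code):

    # Sketch of the preprocessing-tool contract exercised above; the arg keys
    # (session_id, input_messages, attachments) are exactly what this commit passes.
    async def run_preprocessing_tool(
        tool_runtime_api, tool_name, session_id, input_messages, attachments
    ):
        result = await tool_runtime_api.invoke_tool(
            tool_name=tool_name,
            args=dict(
                session_id=session_id,
                input_messages=input_messages,
                attachments=attachments,
            ),
        )
        # ToolInvocationResult carries content plus error_code/error_message.
        return result.content, result.error_code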
@@ -659,129 +639,6 @@ class ChatAgent(ShieldRunnerMixin):
 
             n_iter += 1
 
-    async def _ensure_memory_bank(self, session_id: str) -> str:
-        session_info = await self.storage.get_session_info(session_id)
-        if session_info is None:
-            raise ValueError(f"Session {session_id} not found")
-
-        if session_info.memory_bank_id is None:
-            bank_id = f"memory_bank_{session_id}"
-            await self.memory_banks_api.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
-            )
-            await self.storage.add_memory_bank_to_session(session_id, bank_id)
-        else:
-            bank_id = session_info.memory_bank_id
-
-        return bank_id
-
-    async def _should_retrieve_context(
-        self, messages: List[Message], attachments: List[Attachment]
-    ) -> bool:
-        enabled_tools = set(t.type for t in self.agent_config.tools)
-        if attachments:
-            if (
-                AgentTool.code_interpreter.value in enabled_tools
-                and self.agent_config.tool_choice == ToolChoice.required
-            ):
-                return False
-            else:
-                return True
-
-        return AgentTool.memory.value in enabled_tools
-
-    def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
-        for t in self.agent_config.tools:
-            if t.type == AgentTool.memory.value:
-                return t
-
-        return None
-
-    async def _retrieve_context(
-        self, session_id: str, messages: List[Message], attachments: List[Attachment]
-    ) -> Tuple[Optional[InterleavedContent], List[int]]:  # (rag_context, bank_ids)
-        bank_ids = []
-
-        memory = self._memory_tool_definition()
-        assert memory is not None, "Memory tool not configured"
-        bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
-
-        if attachments:
-            bank_id = await self._ensure_memory_bank(session_id)
-            bank_ids.append(bank_id)
-
-            documents = [
-                MemoryBankDocument(
-                    document_id=str(uuid.uuid4()),
-                    content=a.content,
-                    mime_type=a.mime_type,
-                    metadata={},
-                )
-                for a in attachments
-            ]
-            with tracing.span("insert_documents"):
-                await self.memory_api.insert_documents(bank_id, documents)
-        else:
-            session_info = await self.storage.get_session_info(session_id)
-            if session_info.memory_bank_id:
-                bank_ids.append(session_info.memory_bank_id)
-
-        if not bank_ids:
-            # this can happen if the per-session memory bank is not yet populated
-            # (i.e., no prior turns uploaded an Attachment)
-            return None, []
-
-        query = await generate_rag_query(
-            memory.query_generator_config, messages, inference_api=self.inference_api
-        )
-        tasks = [
-            self.memory_api.query_documents(
-                bank_id=bank_id,
-                query=query,
-                params={
-                    "max_chunks": 5,
-                },
-            )
-            for bank_id in bank_ids
-        ]
-        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
-        chunks = [c for r in results for c in r.chunks]
-        scores = [s for r in results for s in r.scores]
-
-        if not chunks:
-            return None, bank_ids
-
-        # sort by score
-        chunks, scores = zip(
-            *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
-        )
-
-        tokens = 0
-        picked = []
-        for c in chunks[: memory.max_chunks]:
-            tokens += c.token_count
-            if tokens > memory.max_tokens_in_context:
-                log.error(
-                    f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
-                )
-                break
-            picked.append(f"id:{c.document_id}; content:{c.content}")
-
-        return (
-            concat_interleaved_content(
-                [
-                    "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-                    *picked,
-                    "\n=== END-RETRIEVED-CONTEXT ===\n",
-                ]
-            ),
-            bank_ids,
-        )
-
     def _get_tools(self) -> List[ToolDefinition]:
         ret = []
         for t in self.agent_config.tools:
@@ -24,12 +24,11 @@ from llama_stack.apis.agents import (
     Session,
     Turn,
 )
-
 from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
-
 from .agent_instance import ChatAgent
@@ -47,12 +46,16 @@ class MetaReferenceAgentsImpl(Agents):
         memory_api: Memory,
         safety_api: Safety,
         memory_banks_api: MemoryBanks,
+        tool_runtime_api: ToolRuntime,
+        tool_groups_api: ToolGroups,
     ):
         self.config = config
         self.inference_api = inference_api
         self.memory_api = memory_api
         self.safety_api = safety_api
         self.memory_banks_api = memory_banks_api
+        self.tool_runtime_api = tool_runtime_api
+        self.tool_groups_api = tool_groups_api
 
         self.in_memory_store = InmemoryKVStoreImpl()
         self.tempdir = tempfile.mkdtemp()
@@ -112,6 +115,8 @@ class MetaReferenceAgentsImpl(Agents):
             safety_api=self.safety_api,
             memory_api=self.memory_api,
             memory_banks_api=self.memory_banks_api,
+            tool_runtime_api=self.tool_runtime_api,
+            tool_groups_api=self.tool_groups_api,
             persistence_store=(
                 self.persistence_store
                 if agent_config.enable_session_persistence
@@ -8,13 +8,11 @@ import json
 import logging
 import uuid
 from datetime import datetime
-
 from typing import List, Optional
 
 from pydantic import BaseModel
 
 from llama_stack.apis.agents import Turn
-
 from llama_stack.providers.utils.kvstore import KVStore
 
 log = logging.getLogger(__name__)
@@ -23,7 +21,6 @@ log = logging.getLogger(__name__)
 class AgentSessionInfo(BaseModel):
     session_id: str
     session_name: str
-    memory_bank_id: Optional[str] = None
     started_at: datetime
 
 
@@ -54,17 +51,6 @@ class AgentPersistence:
 
         return AgentSessionInfo(**json.loads(value))
 
-    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
-        session_info = await self.get_session_info(session_id)
-        if session_info is None:
-            raise ValueError(f"Session {session_id} not found")
-
-        session_info.memory_bank_id = bank_id
-        await self.kvstore.set(
-            key=f"session:{self.agent_id}:{session_id}",
-            value=session_info.model_dump_json(),
-        )
-
     async def add_turn_to_session(self, session_id: str, turn: Turn):
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
llama_stack/providers/inline/tool_runtime/memory/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+from llama_stack.providers.datatypes import Api
+
+from .config import MemoryToolConfig
+from .memory import MemoryToolRuntimeImpl
+
+
+async def get_provider_impl(config: MemoryToolConfig, deps: Dict[str, Any]):
+    impl = MemoryToolRuntimeImpl(
+        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
+    )
+    await impl.initialize()
+    return impl
llama_stack/providers/inline/tool_runtime/memory/config.py (new file, 93 lines)
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Annotated, List, Literal, Union
+
+from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.providers.utils.kvstore import KVStoreConfig, SqliteKVStoreConfig
+
+from pydantic import BaseModel, Field
+
+
+class _MemoryBankConfigCommon(BaseModel):
+    bank_id: str
+
+
+class VectorMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal["vector"] = "vector"
+
+
+class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal["keyvalue"] = "keyvalue"
+    keys: List[str]  # what keys to focus on
+
+
+class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal["keyword"] = "keyword"
+
+
+class GraphMemoryBankConfig(_MemoryBankConfigCommon):
+    type: Literal["graph"] = "graph"
+    entities: List[str]  # what entities to focus on
+
+
+MemoryBankConfig = Annotated[
+    Union[
+        VectorMemoryBankConfig,
+        KeyValueMemoryBankConfig,
+        KeywordMemoryBankConfig,
+        GraphMemoryBankConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
+class MemoryQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+class DefaultMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.default.value] = (
+        MemoryQueryGenerator.default.value
+    )
+    sep: str = " "
+
+
+class LLMMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
+    model: str
+    template: str
+
+
+class CustomMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
+
+
+MemoryQueryGeneratorConfig = Annotated[
+    Union[
+        DefaultMemoryQueryGeneratorConfig,
+        LLMMemoryQueryGeneratorConfig,
+        CustomMemoryQueryGeneratorConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
+class MemoryToolConfig(BaseModel):
+    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+    # This config defines how a query is generated using the messages
+    # for memory bank retrieval.
+    query_generator_config: MemoryQueryGeneratorConfig = Field(
+        default=DefaultMemoryQueryGeneratorConfig()
+    )
+    max_tokens_in_context: int = 4096
+    max_chunks: int = 10
+    kvstore_config: KVStoreConfig = SqliteKVStoreConfig(
+        db_path=(RUNTIME_BASE_DIR / "memory.db").as_posix()
+    )
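A hedged example of configuring the new memory tool runtime to use an LLM-generated retrieval query instead of the default message concatenation (model name and template string are illustrative; per context_retriever.py the template is rendered with Jinja2):

    # Illustrative values only.
    config = MemoryToolConfig(
        memory_bank_configs=[VectorMemoryBankConfig(bank_id="docs")],
        query_generator_config=LLMMemoryQueryGeneratorConfig(
            model="Llama3.1-8B-Instruct",
            template="Summarize the user's request as a search query: {{ messages }}",
        ),
        max_tokens_in_context=2048,
        max_chunks=5,
    )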
@@ -8,16 +8,17 @@ from typing import List
 
 from jinja2 import Template
 
-from llama_stack.apis.agents import (
+from llama_stack.apis.inference import Message, UserMessage
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+
+from .config import (
     DefaultMemoryQueryGeneratorConfig,
     LLMMemoryQueryGeneratorConfig,
     MemoryQueryGenerator,
     MemoryQueryGeneratorConfig,
 )
-from llama_stack.apis.inference import Message, UserMessage
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
 
 
 async def generate_rag_query(
llama_stack/providers/inline/tool_runtime/memory/memory.py (new file, 253 lines)
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import json
+import logging
+import os
+import re
+import secrets
+import string
+import tempfile
+import uuid
+from typing import Any, Dict, List, Optional
+from urllib.parse import urlparse
+
+import httpx
+
+from llama_stack.apis.agents import Attachment
+from llama_stack.apis.common.content_types import TextContentItem, URL
+from llama_stack.apis.inference import Inference, InterleavedContent, Message
+from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
+from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
+from llama_stack.apis.tools import (
+    ToolDef,
+    ToolGroupDef,
+    ToolInvocationResult,
+    ToolRuntime,
+)
+from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
+from pydantic import BaseModel
+
+from .config import MemoryToolConfig
+from .context_retriever import generate_rag_query
+
+log = logging.getLogger(__name__)
+
+
+class MemorySessionInfo(BaseModel):
+    session_id: str
+    session_name: str
+    memory_bank_id: Optional[str] = None
+
+
+def make_random_string(length: int = 8):
+    return "".join(
+        secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
+    )
+
+
+class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+    def __init__(
+        self,
+        config: MemoryToolConfig,
+        memory_api: Memory,
+        memory_banks_api: MemoryBanks,
+        inference_api: Inference,
+    ):
+        self.config = config
+        self.memory_api = memory_api
+        self.memory_banks_api = memory_banks_api
+        self.tempdir = tempfile.mkdtemp()
+        self.inference_api = inference_api
+
+    async def initialize(self):
+        self.kvstore = await kvstore_impl(self.config.kvstore_config)
+
+    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
+        return []
+
+    async def create_session(self, session_id: str) -> MemorySessionInfo:
+        session_info = MemorySessionInfo(
+            session_id=session_id,
+            session_name=f"session_{session_id}",
+        )
+        await self.kvstore.set(
+            key=f"memory::session:{session_id}",
+            value=session_info.model_dump_json(),
+        )
+        return session_info
+
+    async def get_session_info(self, session_id: str) -> Optional[MemorySessionInfo]:
+        value = await self.kvstore.get(
+            key=f"memory::session:{session_id}",
+        )
+        if not value:
+            session_info = await self.create_session(session_id)
+            return session_info
+
+        return MemorySessionInfo(**json.loads(value))
+
+    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
+        session_info = await self.get_session_info(session_id)
+
+        session_info.memory_bank_id = bank_id
+        await self.kvstore.set(
+            key=f"memory::session:{session_id}",
+            value=session_info.model_dump_json(),
+        )
+
+    async def _ensure_memory_bank(self, session_id: str) -> str:
+        session_info = await self.get_session_info(session_id)
+
+        if session_info.memory_bank_id is None:
+            bank_id = f"memory_bank_{session_id}"
+            await self.memory_banks_api.register_memory_bank(
+                memory_bank_id=bank_id,
+                params=VectorMemoryBankParams(
+                    embedding_model="all-MiniLM-L6-v2",
+                    chunk_size_in_tokens=512,
+                ),
+            )
+            await self.add_memory_bank_to_session(session_id, bank_id)
+        else:
+            bank_id = session_info.memory_bank_id
+
+        return bank_id
+
+    async def attachment_message(
+        self, tempdir: str, urls: List[URL]
+    ) -> List[TextContentItem]:
+        content = []
+
+        for url in urls:
+            uri = url.uri
+            if uri.startswith("file://"):
+                filepath = uri[len("file://") :]
+            elif uri.startswith("http"):
+                path = urlparse(uri).path
+                basename = os.path.basename(path)
+                filepath = f"{tempdir}/{make_random_string() + basename}"
+                log.info(f"Downloading {url} -> {filepath}")
+
+                async with httpx.AsyncClient() as client:
+                    r = await client.get(uri)
+                    resp = r.text
+                    with open(filepath, "w") as fp:
+                        fp.write(resp)
+            else:
+                raise ValueError(f"Unsupported URL {url}")
+
+            content.append(
+                TextContentItem(
+                    text=f'# There is a file accessible to you at "{filepath}"\n'
+                )
+            )
+
+        return content
+
+    async def _retrieve_context(
+        self, session_id: str, messages: List[Message]
+    ) -> Optional[List[InterleavedContent]]:
+        bank_ids = []
+
+        bank_ids.extend(c.bank_id for c in self.config.memory_bank_configs)
+
+        session_info = await self.get_session_info(session_id)
+        if session_info.memory_bank_id:
+            bank_ids.append(session_info.memory_bank_id)
+
+        if not bank_ids:
+            # this can happen if the per-session memory bank is not yet populated
+            # (i.e., no prior turns uploaded an Attachment)
+            return None
+
+        query = await generate_rag_query(
+            self.config.query_generator_config,
+            messages,
+            inference_api=self.inference_api,
+        )
+        tasks = [
+            self.memory_api.query_documents(
+                bank_id=bank_id,
+                query=query,
+                params={
+                    "max_chunks": 5,
+                },
+            )
+            for bank_id in bank_ids
+        ]
+        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+        chunks = [c for r in results for c in r.chunks]
+        scores = [s for r in results for s in r.scores]
+
+        if not chunks:
+            return None
+
+        # sort by score
+        chunks, scores = zip(
+            *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
+        )
+
+        tokens = 0
+        picked = []
+        for c in chunks[: self.config.max_chunks]:
+            tokens += c.token_count
+            if tokens > self.config.max_tokens_in_context:
+                log.error(
+                    f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
+                )
+                break
+            picked.append(f"id:{c.document_id}; content:{c.content}")
+
+        return [
+            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+            *picked,
+            "\n=== END-RETRIEVED-CONTEXT ===\n",
+        ]
+
+    async def _process_attachments(
+        self, session_id: str, attachments: List[Attachment]
+    ):
+        bank_id = await self._ensure_memory_bank(session_id)
+
+        documents = [
+            MemoryBankDocument(
+                document_id=str(uuid.uuid4()),
+                content=a.content,
+                mime_type=a.mime_type,
+                metadata={},
+            )
+            for a in attachments
+            if isinstance(a.content, str)
+        ]
+        await self.memory_api.insert_documents(bank_id, documents)
+
+        urls = [a.content for a in attachments if isinstance(a.content, URL)]
+        # TODO: we need to migrate URL away from str type
+        pattern = re.compile("^(https?://|file://|data:)")
+        urls += [URL(uri=a.content) for a in attachments if pattern.match(a.content)]
+        return await self.attachment_message(self.tempdir, urls)
+
+    async def invoke_tool(
+        self, tool_name: str, args: Dict[str, Any]
+    ) -> ToolInvocationResult:
+        if args["session_id"] is None:
+            raise ValueError("session_id is required")
+
+        context = await self._retrieve_context(
+            args["session_id"], args["input_messages"]
+        )
+        if context is None:
+            context = []
+        attachments = args["attachments"]
+        if attachments and len(attachments) > 0:
+            context += await self._process_attachments(args["session_id"], attachments)
+        return ToolInvocationResult(
+            content=concat_interleaved_content(context), error_code=0
+        )
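Putting the new runtime together, a hedged end-to-end sketch of how a caller exercises this provider; in a real stack the instantiation happens through get_provider_impl() and the three API arguments are the resolved Memory, MemoryBanks, and Inference impls:

    # Sketch: apis are assumed to be already-resolved implementations.
    async def demo(memory_api, memory_banks_api, inference_api, messages):
        impl = MemoryToolRuntimeImpl(
            MemoryToolConfig(), memory_api, memory_banks_api, inference_api
        )
        await impl.initialize()
        result = await impl.invoke_tool(
            tool_name="memory",
            args={
                "session_id": "demo-session",
                "input_messages": messages,
                "attachments": None,
            },
        )
        return result.content  # concatenated retrieved context, or empty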
@@ -35,6 +35,8 @@ def available_providers() -> List[ProviderSpec]:
             Api.safety,
             Api.memory,
             Api.memory_banks,
+            Api.tool_runtime,
+            Api.tool_groups,
         ],
     ),
     remote_provider_spec(
@@ -25,6 +25,14 @@ def available_providers() -> List[ProviderSpec]:
         config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
         provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
     ),
+    InlineProviderSpec(
+        api=Api.tool_runtime,
+        provider_type="inline::memory-runtime",
+        pip_packages=[],
+        module="llama_stack.providers.inline.tool_runtime.memory",
+        config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig",
+        api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
+    ),
     remote_provider_spec(
         api=Api.tool_runtime,
         adapter=AdapterSpec(
@@ -7,12 +7,10 @@
 import pytest
 
 from ..conftest import get_provider_fixture_overrides
-
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
-from .fixtures import AGENTS_FIXTURES
+from .fixtures import AGENTS_FIXTURES, TOOL_RUNTIME_FIXTURES
 
-
 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
@@ -21,6 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "llama_guard",
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory",
         },
         id="meta_reference",
         marks=pytest.mark.meta_reference,
@@ -31,6 +30,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "llama_guard",
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory",
         },
         id="ollama",
         marks=pytest.mark.ollama,
@@ -42,6 +42,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             # make this work with Weaviate which is what the together distro supports
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory",
         },
         id="together",
         marks=pytest.mark.together,
@@ -52,6 +53,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "llama_guard",
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory",
         },
         id="fireworks",
         marks=pytest.mark.fireworks,
@@ -62,6 +64,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "remote",
             "memory": "remote",
             "agents": "remote",
+            "tool_runtime": "memory",
         },
         id="remote",
         marks=pytest.mark.remote,
@@ -117,6 +120,7 @@ def pytest_generate_tests(metafunc):
         "safety": SAFETY_FIXTURES,
         "memory": MEMORY_FIXTURES,
         "agents": AGENTS_FIXTURES,
+        "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }
     combinations = (
         get_provider_fixture_overrides(metafunc.config, available_fixtures)
@@ -10,14 +10,19 @@ import pytest
 import pytest_asyncio
 
 from llama_stack.apis.models import ModelInput, ModelType
+from llama_stack.apis.tools import (
+    ToolDef,
+    ToolGroupInput,
+    ToolParameter,
+    UserDefinedToolGroupDef,
+)
 from llama_stack.distribution.datatypes import Api, Provider
-
 from llama_stack.providers.inline.agents.meta_reference import (
     MetaReferenceAgentsImplConfig,
 )
 
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 from ..conftest import ProviderFixture, remote_stack_fixture
 
 
@@ -55,7 +60,21 @@ def agents_meta_reference() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def tool_runtime_memory() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="memory-runtime",
+                provider_type="inline::memory-runtime",
+                config={},
+            )
+        ],
+    )
+
+
 AGENTS_FIXTURES = ["meta_reference", "remote"]
+TOOL_RUNTIME_FIXTURES = ["memory"]
 
 
 @pytest_asyncio.fixture(scope="session")
@@ -64,7 +83,7 @@ async def agents_stack(request, inference_model, safety_shield):
 
     providers = {}
     provider_data = {}
-    for key in ["inference", "safety", "memory", "agents"]:
+    for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if key == "inference":
@@ -111,12 +130,48 @@ async def agents_stack(request, inference_model, safety_shield):
                 metadata={"embedding_dimension": 384},
             )
         )
+    tool_groups = [
+        ToolGroupInput(
+            tool_group_id="memory_group",
+            tool_group=UserDefinedToolGroupDef(
+                tools=[
+                    ToolDef(
+                        name="memory",
+                        description="memory",
+                        parameters=[
+                            ToolParameter(
+                                name="session_id",
+                                description="session id",
+                                parameter_type="string",
+                                required=True,
+                            ),
+                            ToolParameter(
+                                name="input_messages",
+                                description="messages",
+                                parameter_type="list",
+                                required=True,
+                            ),
+                            ToolParameter(
+                                name="attachments",
+                                description="attachments",
+                                parameter_type="list",
+                                required=False,
+                            ),
+                        ],
+                        metadata={},
+                    )
+                ],
+            ),
+            provider_id="memory-runtime",
+        )
+    ]
 
     test_stack = await construct_stack_for_test(
-        [Api.agents, Api.inference, Api.safety, Api.memory],
+        [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
         providers,
         provider_data,
         models=models,
         shields=[safety_shield] if safety_shield else [],
+        tool_groups=tool_groups,
     )
     return test_stack
@@ -35,7 +35,6 @@ from llama_stack.providers.datatypes import Api
 #
 # pytest -v -s llama_stack/providers/tests/agents/test_agents.py
 #   -m "meta_reference"
-
 from .fixtures import pick_inference_model
 from .utils import create_agent_session
 
@@ -255,17 +254,8 @@ class TestAgents:
         agent_config = AgentConfig(
             **{
                 **common_params,
-                "tools": [
-                    MemoryToolDefinition(
-                        memory_bank_configs=[],
-                        query_generator_config={
-                            "type": "default",
-                            "sep": " ",
-                        },
-                        max_tokens_in_context=4096,
-                        max_chunks=10,
-                    ),
-                ],
+                "tools": [],
+                "preprocessing_tools": ["memory"],
                 "tool_choice": ToolChoice.auto,
             }
         )
@@ -16,7 +16,7 @@ from llama_stack.apis.memory_banks import MemoryBankInput
 from llama_stack.apis.models import ModelInput
 from llama_stack.apis.scoring_functions import ScoringFnInput
 from llama_stack.apis.shields import ShieldInput
+from llama_stack.apis.tools import ToolGroupInput
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Provider, StackRunConfig
@@ -43,6 +43,7 @@ async def construct_stack_for_test(
     datasets: Optional[List[DatasetInput]] = None,
     scoring_fns: Optional[List[ScoringFnInput]] = None,
     eval_tasks: Optional[List[EvalTaskInput]] = None,
+    tool_groups: Optional[List[ToolGroupInput]] = None,
 ) -> TestStack:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     run_config = dict(
@@ -56,6 +57,7 @@ async def construct_stack_for_test(
         datasets=datasets or [],
         scoring_fns=scoring_fns or [],
         eval_tasks=eval_tasks or [],
+        tool_groups=tool_groups or [],
     )
     run_config = parse_and_maybe_upgrade_config(run_config)
     try: