Several smaller fixes to make adapters work

Also, reorganized the pattern of __init__ inside providers so
configuration can stay lightweight
This commit is contained in:
Ashwin Bharambe 2024-08-28 09:42:08 -07:00
parent 2a1552a5eb
commit 45987996c4
23 changed files with 164 additions and 160 deletions

View file

@ -194,11 +194,6 @@ async def run_rag(host: str, port: int):
MemoryToolDefinition(
max_tokens_in_context=2048,
memory_bank_configs=[],
# memory_bank_configs=[
# AgenticSystemVectorMemoryBankConfig(
# bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
# )
# ],
),
]
@ -210,8 +205,9 @@ async def run_rag(host: str, port: int):
await _run_agent(api, tool_definitions, user_prompts, attachments)
def main(host: str, port: int):
asyncio.run(run_rag(host, port))
def main(host: str, port: int, rag: bool = False):
fn = run_rag if rag else run_main
asyncio.run(fn(host, port))
if __name__ == "__main__":

View file

@ -4,5 +4,27 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .agentic_system import get_provider_impl # noqa
from .config import MetaReferenceImplConfig # noqa
from typing import Dict
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceImplConfig
async def get_provider_impl(
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
):
from .agentic_system import MetaReferenceAgenticSystemImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgenticSystemImpl(
config,
deps[Api.inference],
deps[Api.memory],
deps[Api.safety],
)
await impl.initialize()
return impl

View file

@ -8,9 +8,8 @@
import logging
import os
import uuid
from typing import AsyncGenerator, Dict
from typing import AsyncGenerator
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import Inference
from llama_toolchain.memory.api import Memory
from llama_toolchain.safety.api import Safety
@ -31,23 +30,6 @@ logger = logging.getLogger()
logger.setLevel(logging.INFO)
async def get_provider_impl(
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
):
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgenticSystemImpl(
config,
deps[Api.inference],
deps[Api.memory],
deps[Api.safety],
)
await impl.initialize()
return impl
AGENT_INSTANCES_BY_ID = {}