mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
inference + memory + agents tests now pass with "remote" providers
This commit is contained in:
parent
fc66131fea
commit
386372dd24
6 changed files with 127 additions and 91 deletions
|
@ -8,51 +8,51 @@ import inspect
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, get_args, get_origin, Type, Union
|
from typing import Any, get_args, get_origin, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from llama_models.schema_utils import WebMethod
|
|
||||||
from pydantic import BaseModel, parse_obj_as
|
from pydantic import BaseModel, parse_obj_as
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||||
def extract_non_async_iterator_type(type_hint):
|
|
||||||
if get_origin(type_hint) is Union:
|
|
||||||
args = get_args(type_hint)
|
|
||||||
for arg in args:
|
|
||||||
if not issubclass(get_origin(arg) or arg, AsyncIterator):
|
|
||||||
return arg
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def extract_async_iterator_type(type_hint):
|
|
||||||
if get_origin(type_hint) is Union:
|
|
||||||
args = get_args(type_hint)
|
|
||||||
for arg in args:
|
|
||||||
if issubclass(get_origin(arg) or arg, AsyncIterator):
|
|
||||||
inner_args = get_args(arg)
|
|
||||||
return inner_args[0]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
_CLIENT_CLASSES = {}
|
_CLIENT_CLASSES = {}
|
||||||
|
|
||||||
|
|
||||||
def create_api_client_class(protocol) -> Type:
|
async def get_client_impl(
|
||||||
|
protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
|
||||||
|
):
|
||||||
|
client_class = create_api_client_class(protocol, additional_protocol)
|
||||||
|
impl = client_class(config.url)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
def create_api_client_class(protocol, additional_protocol) -> Type:
|
||||||
if protocol in _CLIENT_CLASSES:
|
if protocol in _CLIENT_CLASSES:
|
||||||
return _CLIENT_CLASSES[protocol]
|
return _CLIENT_CLASSES[protocol]
|
||||||
|
|
||||||
|
protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
|
||||||
|
|
||||||
class APIClient:
|
class APIClient:
|
||||||
def __init__(self, base_url: str):
|
def __init__(self, base_url: str):
|
||||||
|
print(f"({protocol.__name__}) Connecting to {base_url}")
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
self.routes = {}
|
self.routes = {}
|
||||||
|
|
||||||
# Store routes for this protocol
|
# Store routes for this protocol
|
||||||
for name, method in inspect.getmembers(protocol):
|
for p in protocols:
|
||||||
if hasattr(method, "__webmethod__"):
|
for name, method in inspect.getmembers(p):
|
||||||
sig = inspect.signature(method)
|
if hasattr(method, "__webmethod__"):
|
||||||
self.routes[name] = (method.__webmethod__, sig)
|
sig = inspect.signature(method)
|
||||||
|
self.routes[name] = (method.__webmethod__, sig)
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
|
async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
|
||||||
assert method_name in self.routes, f"Unknown endpoint: {method_name}"
|
assert method_name in self.routes, f"Unknown endpoint: {method_name}"
|
||||||
|
@ -65,21 +65,23 @@ def create_api_client_class(protocol) -> Type:
|
||||||
return await self._call_non_streaming(method_name, *args, **kwargs)
|
return await self._call_non_streaming(method_name, *args, **kwargs)
|
||||||
|
|
||||||
async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
|
async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
|
||||||
webmethod, sig = self.routes[method_name]
|
_, sig = self.routes[method_name]
|
||||||
|
|
||||||
return_type = extract_non_async_iterator_type(sig.return_annotation)
|
if sig.return_annotation is None:
|
||||||
assert (
|
return_type = None
|
||||||
return_type
|
else:
|
||||||
), f"Could not extract return type for {sig.return_annotation}"
|
return_type = extract_non_async_iterator_type(sig.return_annotation)
|
||||||
cprint(f"{return_type=}", "yellow")
|
assert (
|
||||||
|
return_type
|
||||||
|
), f"Could not extract return type for {sig.return_annotation}"
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
params = self.httpx_request_params(webmethod, **kwargs)
|
params = self.httpx_request_params(method_name, *args, **kwargs)
|
||||||
response = await client.request(**params)
|
response = await client.request(**params)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
j = response.json()
|
j = response.json()
|
||||||
if not j:
|
if j is None:
|
||||||
return None
|
return None
|
||||||
return parse_obj_as(return_type, j)
|
return parse_obj_as(return_type, j)
|
||||||
|
|
||||||
|
@ -90,10 +92,9 @@ def create_api_client_class(protocol) -> Type:
|
||||||
assert (
|
assert (
|
||||||
return_type
|
return_type
|
||||||
), f"Could not extract return type for {sig.return_annotation}"
|
), f"Could not extract return type for {sig.return_annotation}"
|
||||||
cprint(f"{return_type=}", "yellow")
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
params = self.httpx_request_params(webmethod, **kwargs)
|
params = self.httpx_request_params(method_name, *args, **kwargs)
|
||||||
async with client.stream(**params) as response:
|
async with client.stream(**params) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
@ -110,7 +111,15 @@ def create_api_client_class(protocol) -> Type:
|
||||||
print(data)
|
print(data)
|
||||||
print(f"Error with parsing or validation: {e}")
|
print(f"Error with parsing or validation: {e}")
|
||||||
|
|
||||||
def httpx_request_params(self, webmethod: WebMethod, **kwargs) -> dict:
|
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
|
||||||
|
webmethod, sig = self.routes[method_name]
|
||||||
|
|
||||||
|
parameters = list(sig.parameters.values())[1:] # skip `self`
|
||||||
|
for i, param in enumerate(parameters):
|
||||||
|
if i >= len(args):
|
||||||
|
break
|
||||||
|
kwargs[param.name] = args[i]
|
||||||
|
|
||||||
url = f"{self.base_url}{webmethod.route}"
|
url = f"{self.base_url}{webmethod.route}"
|
||||||
|
|
||||||
def convert(value):
|
def convert(value):
|
||||||
|
@ -119,7 +128,9 @@ def create_api_client_class(protocol) -> Type:
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
return {k: convert(v) for k, v in value.items()}
|
return {k: convert(v) for k, v in value.items()}
|
||||||
elif isinstance(value, BaseModel):
|
elif isinstance(value, BaseModel):
|
||||||
return json.loads(value.json())
|
return json.loads(value.model_dump_json())
|
||||||
|
elif isinstance(value, Enum):
|
||||||
|
return value.value
|
||||||
else:
|
else:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@ -140,16 +151,17 @@ def create_api_client_class(protocol) -> Type:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add protocol methods to the wrapper
|
# Add protocol methods to the wrapper
|
||||||
for name, method in inspect.getmembers(protocol):
|
for p in protocols:
|
||||||
if hasattr(method, "__webmethod__"):
|
for name, method in inspect.getmembers(p):
|
||||||
|
if hasattr(method, "__webmethod__"):
|
||||||
|
|
||||||
async def method_impl(self, *args, method_name=name, **kwargs):
|
async def method_impl(self, *args, method_name=name, **kwargs):
|
||||||
return await self.__acall__(method_name, *args, **kwargs)
|
return await self.__acall__(method_name, *args, **kwargs)
|
||||||
|
|
||||||
method_impl.__name__ = name
|
method_impl.__name__ = name
|
||||||
method_impl.__qualname__ = f"APIClient.{name}"
|
method_impl.__qualname__ = f"APIClient.{name}"
|
||||||
method_impl.__signature__ = inspect.signature(method)
|
method_impl.__signature__ = inspect.signature(method)
|
||||||
setattr(APIClient, name, method_impl)
|
setattr(APIClient, name, method_impl)
|
||||||
|
|
||||||
# Name the class after the protocol
|
# Name the class after the protocol
|
||||||
APIClient.__name__ = f"{protocol.__name__}Client"
|
APIClient.__name__ = f"{protocol.__name__}Client"
|
||||||
|
@ -157,6 +169,26 @@ def create_api_client_class(protocol) -> Type:
|
||||||
return APIClient
|
return APIClient
|
||||||
|
|
||||||
|
|
||||||
|
# not quite general these methods are
|
||||||
|
def extract_non_async_iterator_type(type_hint):
|
||||||
|
if get_origin(type_hint) is Union:
|
||||||
|
args = get_args(type_hint)
|
||||||
|
for arg in args:
|
||||||
|
if not issubclass(get_origin(arg) or arg, AsyncIterator):
|
||||||
|
return arg
|
||||||
|
return type_hint
|
||||||
|
|
||||||
|
|
||||||
|
def extract_async_iterator_type(type_hint):
|
||||||
|
if get_origin(type_hint) is Union:
|
||||||
|
args = get_args(type_hint)
|
||||||
|
for arg in args:
|
||||||
|
if issubclass(get_origin(arg) or arg, AsyncIterator):
|
||||||
|
inner_args = get_args(arg)
|
||||||
|
return inner_args[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def example(model: str = None):
|
async def example(model: str = None):
|
||||||
from llama_stack.apis.inference import Inference, UserMessage # noqa: F403
|
from llama_stack.apis.inference import Inference, UserMessage # noqa: F403
|
||||||
from llama_stack.apis.inference.event_logger import EventLogger
|
from llama_stack.apis.inference.event_logger import EventLogger
|
||||||
|
|
|
@ -40,19 +40,21 @@ def api_protocol_map() -> Dict[Api, Any]:
|
||||||
Api.safety: Safety,
|
Api.safety: Safety,
|
||||||
Api.shields: Shields,
|
Api.shields: Shields,
|
||||||
Api.telemetry: Telemetry,
|
Api.telemetry: Telemetry,
|
||||||
Api.datasets: Datasets,
|
|
||||||
Api.datasetio: DatasetIO,
|
Api.datasetio: DatasetIO,
|
||||||
Api.scoring_functions: ScoringFunctions,
|
Api.datasets: Datasets,
|
||||||
Api.scoring: Scoring,
|
Api.scoring: Scoring,
|
||||||
|
Api.scoring_functions: ScoringFunctions,
|
||||||
Api.eval: Eval,
|
Api.eval: Eval,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def additional_protocols_map() -> Dict[Api, Any]:
|
def additional_protocols_map() -> Dict[Api, Any]:
|
||||||
return {
|
return {
|
||||||
Api.inference: ModelsProtocolPrivate,
|
Api.inference: (ModelsProtocolPrivate, Models),
|
||||||
Api.memory: MemoryBanksProtocolPrivate,
|
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
|
||||||
Api.safety: ShieldsProtocolPrivate,
|
Api.safety: (ShieldsProtocolPrivate, Shields),
|
||||||
|
Api.datasetio: (DatasetsProtocolPrivate, Datasets),
|
||||||
|
Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -112,8 +114,6 @@ async def resolve_impls(
|
||||||
if info.router_api.value not in apis_to_serve:
|
if info.router_api.value not in apis_to_serve:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
|
|
||||||
|
|
||||||
providers_with_specs[info.routing_table_api.value] = {
|
providers_with_specs[info.routing_table_api.value] = {
|
||||||
"__builtin__": ProviderWithSpec(
|
"__builtin__": ProviderWithSpec(
|
||||||
provider_id="__routing_table__",
|
provider_id="__routing_table__",
|
||||||
|
@ -246,14 +246,21 @@ async def instantiate_provider(
|
||||||
|
|
||||||
args = []
|
args = []
|
||||||
if isinstance(provider_spec, RemoteProviderSpec):
|
if isinstance(provider_spec, RemoteProviderSpec):
|
||||||
if provider_spec.adapter:
|
|
||||||
method = "get_adapter_impl"
|
|
||||||
else:
|
|
||||||
method = "get_client_impl"
|
|
||||||
|
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
config = config_type(**provider.config)
|
config = config_type(**provider.config)
|
||||||
args = [config, deps]
|
|
||||||
|
if provider_spec.adapter:
|
||||||
|
method = "get_adapter_impl"
|
||||||
|
args = [config, deps]
|
||||||
|
else:
|
||||||
|
method = "get_client_impl"
|
||||||
|
protocol = protocols[provider_spec.api]
|
||||||
|
if provider_spec.api in additional_protocols:
|
||||||
|
_, additional_protocol = additional_protocols[provider_spec.api]
|
||||||
|
else:
|
||||||
|
additional_protocol = None
|
||||||
|
args = [protocol, additional_protocol, config, deps]
|
||||||
|
|
||||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||||
method = "get_auto_router_impl"
|
method = "get_auto_router_impl"
|
||||||
|
|
||||||
|
@ -282,7 +289,7 @@ async def instantiate_provider(
|
||||||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||||
and provider_spec.api in additional_protocols
|
and provider_spec.api in additional_protocols
|
||||||
):
|
):
|
||||||
additional_api = additional_protocols[provider_spec.api]
|
additional_api, _ = additional_protocols[provider_spec.api]
|
||||||
check_protocol_compliance(impl, additional_api)
|
check_protocol_compliance(impl, additional_api)
|
||||||
|
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -22,6 +22,13 @@ def get_impl_api(p: Any) -> Api:
|
||||||
|
|
||||||
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
api = get_impl_api(p)
|
api = get_impl_api(p)
|
||||||
|
|
||||||
|
if obj.provider_id == "remote":
|
||||||
|
# if this is just a passthrough, we want to let the remote
|
||||||
|
# end actually do the registration with the correct provider
|
||||||
|
obj = obj.model_copy(deep=True)
|
||||||
|
obj.provider_id = ""
|
||||||
|
|
||||||
if api == Api.inference:
|
if api == Api.inference:
|
||||||
await p.register_model(obj)
|
await p.register_model(obj)
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
|
@ -51,11 +58,22 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
self.registry: Registry = {}
|
self.registry: Registry = {}
|
||||||
|
|
||||||
def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
|
def add_objects(
|
||||||
|
objs: List[RoutableObjectWithProvider], provider_id: str, cls
|
||||||
|
) -> None:
|
||||||
for obj in objs:
|
for obj in objs:
|
||||||
if obj.identifier not in self.registry:
|
if obj.identifier not in self.registry:
|
||||||
self.registry[obj.identifier] = []
|
self.registry[obj.identifier] = []
|
||||||
|
|
||||||
|
if cls is None:
|
||||||
|
obj.provider_id = provider_id
|
||||||
|
else:
|
||||||
|
if provider_id == "remote":
|
||||||
|
# if this is just a passthrough, we got the *WithProvider object
|
||||||
|
# so we should just override the provider in-place
|
||||||
|
obj.provider_id = provider_id
|
||||||
|
else:
|
||||||
|
obj = cls(**obj.model_dump(), provider_id=provider_id)
|
||||||
self.registry[obj.identifier].append(obj)
|
self.registry[obj.identifier].append(obj)
|
||||||
|
|
||||||
for pid, p in self.impls_by_provider_id.items():
|
for pid, p in self.impls_by_provider_id.items():
|
||||||
|
@ -63,47 +81,27 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
if api == Api.inference:
|
if api == Api.inference:
|
||||||
p.model_store = self
|
p.model_store = self
|
||||||
models = await p.list_models()
|
models = await p.list_models()
|
||||||
add_objects(
|
add_objects(models, pid, ModelDefWithProvider)
|
||||||
[ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
|
|
||||||
)
|
|
||||||
|
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
p.shield_store = self
|
p.shield_store = self
|
||||||
shields = await p.list_shields()
|
shields = await p.list_shields()
|
||||||
add_objects(
|
add_objects(shields, pid, ShieldDefWithProvider)
|
||||||
[
|
|
||||||
ShieldDefWithProvider(**s.dict(), provider_id=pid)
|
|
||||||
for s in shields
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
elif api == Api.memory:
|
elif api == Api.memory:
|
||||||
p.memory_bank_store = self
|
p.memory_bank_store = self
|
||||||
memory_banks = await p.list_memory_banks()
|
memory_banks = await p.list_memory_banks()
|
||||||
|
add_objects(memory_banks, pid, None)
|
||||||
# do in-memory updates due to pesky Annotated unions
|
|
||||||
for m in memory_banks:
|
|
||||||
m.provider_id = pid
|
|
||||||
|
|
||||||
add_objects(memory_banks)
|
|
||||||
|
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
p.dataset_store = self
|
p.dataset_store = self
|
||||||
datasets = await p.list_datasets()
|
datasets = await p.list_datasets()
|
||||||
|
add_objects(datasets, pid, DatasetDefWithProvider)
|
||||||
# do in-memory updates due to pesky Annotated unions
|
|
||||||
for d in datasets:
|
|
||||||
d.provider_id = pid
|
|
||||||
|
|
||||||
elif api == Api.scoring:
|
elif api == Api.scoring:
|
||||||
p.scoring_function_store = self
|
p.scoring_function_store = self
|
||||||
scoring_functions = await p.list_scoring_functions()
|
scoring_functions = await p.list_scoring_functions()
|
||||||
add_objects(
|
add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
|
||||||
[
|
|
||||||
ScoringFnDefWithProvider(**s.dict(), provider_id=pid)
|
|
||||||
for s in scoring_functions
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
for p in self.impls_by_provider_id.values():
|
for p in self.impls_by_provider_id.values():
|
||||||
|
|
|
@ -60,7 +60,7 @@ class MemoryBanksProtocolPrivate(Protocol):
|
||||||
class DatasetsProtocolPrivate(Protocol):
|
class DatasetsProtocolPrivate(Protocol):
|
||||||
async def list_datasets(self) -> List[DatasetDef]: ...
|
async def list_datasets(self) -> List[DatasetDef]: ...
|
||||||
|
|
||||||
async def register_datasets(self, dataset_def: DatasetDef) -> None: ...
|
async def register_dataset(self, dataset_def: DatasetDef) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionsProtocolPrivate(Protocol):
|
class ScoringFunctionsProtocolPrivate(Protocol):
|
||||||
|
@ -171,7 +171,7 @@ as being "Llama Stack compatible"
|
||||||
def module(self) -> str:
|
def module(self) -> str:
|
||||||
if self.adapter:
|
if self.adapter:
|
||||||
return self.adapter.module
|
return self.adapter.module
|
||||||
return f"llama_stack.apis.{self.api.value}.client"
|
return "llama_stack.distribution.client"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pip_packages(self) -> List[str]:
|
def pip_packages(self) -> List[str]:
|
||||||
|
|
|
@ -26,6 +26,7 @@ from dotenv import load_dotenv
|
||||||
#
|
#
|
||||||
# ```bash
|
# ```bash
|
||||||
# PROVIDER_ID=<your_provider> \
|
# PROVIDER_ID=<your_provider> \
|
||||||
|
# MODEL_ID=<your_model> \
|
||||||
# PROVIDER_CONFIG=provider_config.yaml \
|
# PROVIDER_CONFIG=provider_config.yaml \
|
||||||
# pytest -s llama_stack/providers/tests/agents/test_agents.py \
|
# pytest -s llama_stack/providers/tests/agents/test_agents.py \
|
||||||
# --tb=short --disable-warnings
|
# --tb=short --disable-warnings
|
||||||
|
@ -44,7 +45,7 @@ async def agents_settings():
|
||||||
"impl": impls[Api.agents],
|
"impl": impls[Api.agents],
|
||||||
"memory_impl": impls[Api.memory],
|
"memory_impl": impls[Api.memory],
|
||||||
"common_params": {
|
"common_params": {
|
||||||
"model": "Llama3.1-8B-Instruct",
|
"model": os.environ["MODEL_ID"] or "Llama3.1-8B-Instruct",
|
||||||
"instructions": "You are a helpful assistant.",
|
"instructions": "You are a helpful assistant.",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
@ -73,7 +72,6 @@ async def register_memory_bank(banks_impl: MemoryBanks):
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
overlap_size_in_tokens=64,
|
overlap_size_in_tokens=64,
|
||||||
provider_id=os.environ["PROVIDER_ID"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await banks_impl.register_memory_bank(bank)
|
await banks_impl.register_memory_bank(bank)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue