mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Several smaller fixes to make adapters work
Also, reorganized the pattern of __init__ inside providers so configuration can stay lightweight
This commit is contained in:
parent
2a1552a5eb
commit
45987996c4
23 changed files with 164 additions and 160 deletions
|
@ -194,11 +194,6 @@ async def run_rag(host: str, port: int):
|
|||
MemoryToolDefinition(
|
||||
max_tokens_in_context=2048,
|
||||
memory_bank_configs=[],
|
||||
# memory_bank_configs=[
|
||||
# AgenticSystemVectorMemoryBankConfig(
|
||||
# bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
|
||||
# )
|
||||
# ],
|
||||
),
|
||||
]
|
||||
|
||||
|
@ -210,8 +205,9 @@ async def run_rag(host: str, port: int):
|
|||
await _run_agent(api, tool_definitions, user_prompts, attachments)
|
||||
|
||||
|
||||
def main(host: str, port: int):
|
||||
asyncio.run(run_rag(host, port))
|
||||
def main(host: str, port: int, rag: bool = False):
|
||||
fn = run_rag if rag else run_main
|
||||
asyncio.run(fn(host, port))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -4,5 +4,27 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .agentic_system import get_provider_impl # noqa
|
||||
from .config import MetaReferenceImplConfig # noqa
|
||||
from typing import Dict
|
||||
|
||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
|
||||
):
|
||||
from .agentic_system import MetaReferenceAgenticSystemImpl
|
||||
|
||||
assert isinstance(
|
||||
config, MetaReferenceImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceAgenticSystemImpl(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.memory],
|
||||
deps[Api.safety],
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -8,9 +8,8 @@
|
|||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Dict
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_toolchain.inference.api import Inference
|
||||
from llama_toolchain.memory.api import Memory
|
||||
from llama_toolchain.safety.api import Safety
|
||||
|
@ -31,23 +30,6 @@ logger = logging.getLogger()
|
|||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
|
||||
):
|
||||
assert isinstance(
|
||||
config, MetaReferenceImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceAgenticSystemImpl(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.memory],
|
||||
deps[Api.safety],
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
AGENT_INSTANCES_BY_ID = {}
|
||||
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import os
|
|||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
|
@ -37,26 +37,17 @@ def get_dependencies(
|
|||
) -> Dependencies:
|
||||
from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
|
||||
|
||||
def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
|
||||
if isinstance(provider, InlineProviderSpec):
|
||||
return provider.pip_packages, provider.docker_image
|
||||
else:
|
||||
if provider.adapter:
|
||||
return provider.adapter.pip_packages, None
|
||||
return [], None
|
||||
|
||||
pip_packages, docker_image = _deps(provider)
|
||||
pip_packages = provider.pip_packages
|
||||
for dep in dependencies.values():
|
||||
dep_pip_packages, dep_docker_image = _deps(dep)
|
||||
if docker_image and dep_docker_image:
|
||||
if dep.docker_image:
|
||||
raise ValueError(
|
||||
"You can only have the root provider specify a docker image"
|
||||
)
|
||||
|
||||
pip_packages.extend(dep_pip_packages)
|
||||
pip_packages.extend(dep.pip_packages)
|
||||
|
||||
return Dependencies(
|
||||
docker_image=docker_image, pip_packages=pip_packages + SERVER_DEPENDENCIES
|
||||
docker_image=provider.docker_image,
|
||||
pip_packages=pip_packages + SERVER_DEPENDENCIES
|
||||
)
|
||||
|
||||
|
||||
|
@ -158,6 +149,7 @@ class ApiBuild(Subcommand):
|
|||
build_dir = BUILDS_BASE_DIR / args.api
|
||||
os.makedirs(build_dir, exist_ok=True)
|
||||
|
||||
# get these names straight. too confusing.
|
||||
provider_deps = parse_dependencies(args.dependencies or "", self.parser)
|
||||
dependencies = get_dependencies(providers[args.provider], provider_deps)
|
||||
|
||||
|
@ -167,7 +159,7 @@ class ApiBuild(Subcommand):
|
|||
api.value: {
|
||||
"provider_id": args.provider,
|
||||
},
|
||||
**{k: {"provider_id": v} for k, v in provider_deps.items()},
|
||||
**provider_deps,
|
||||
}
|
||||
with open(package_file, "w") as f:
|
||||
c = PackageConfig(
|
||||
|
|
|
@ -48,7 +48,10 @@ class ApiConfigure(Subcommand):
|
|||
)
|
||||
|
||||
def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
|
||||
config_file = BUILDS_BASE_DIR / args.api / f"{args.name}.yaml"
|
||||
name = args.name
|
||||
if not name.endswith(".yaml"):
|
||||
name += ".yaml"
|
||||
config_file = BUILDS_BASE_DIR / args.api / name
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"Could not find {config_file}. Please run `llama api build` first"
|
||||
|
@ -79,10 +82,19 @@ def configure_llama_provider(config_file: Path) -> None:
|
|||
)
|
||||
|
||||
provider_spec = providers[provider_id]
|
||||
cprint(f"Configuring API surface: {api}", "white", attrs=["bold"])
|
||||
cprint(
|
||||
f"Configuring API surface: {api} ({provider_id})", "white", attrs=["bold"]
|
||||
)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
|
||||
try:
|
||||
existing_provider_config = config_type(**stub_config)
|
||||
except KeyError:
|
||||
existing_provider_config = None
|
||||
|
||||
provider_config = prompt_for_config(
|
||||
config_type,
|
||||
existing_provider_config,
|
||||
)
|
||||
print("")
|
||||
|
||||
|
|
|
@ -29,10 +29,9 @@ class ApiStart(Subcommand):
|
|||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"--yaml-config",
|
||||
"yaml_config",
|
||||
type=str,
|
||||
help="Yaml config containing the API build configuration",
|
||||
required=True,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--port",
|
||||
|
|
|
@ -69,7 +69,7 @@ ensure_conda_env_python310() {
|
|||
conda create -n "${env_name}" python="${python_version}" -y
|
||||
|
||||
ENVNAME="${env_name}"
|
||||
setup_cleanup_handlers
|
||||
# setup_cleanup_handlers
|
||||
fi
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
|
|
|
@ -36,16 +36,14 @@ class ProviderSpec(BaseModel):
|
|||
...,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
api_dependencies: List[Api] = Field(
|
||||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
"""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
||||
as being "Llama Stack compatible"
|
||||
"""
|
||||
|
||||
adapter_id: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
|
@ -89,11 +87,6 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||
""",
|
||||
)
|
||||
api_dependencies: List[Api] = Field(
|
||||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
is_adapter: bool = False
|
||||
|
||||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
|
@ -113,34 +106,41 @@ def remote_provider_id(adapter_id: str) -> str:
|
|||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
provider_id: str = "remote"
|
||||
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||
adapter: Optional[AdapterSpec] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
||||
as being "Llama Stack compatible"
|
||||
""",
|
||||
)
|
||||
|
||||
@property
|
||||
def module(self) -> str:
|
||||
if self.adapter:
|
||||
return self.adapter.module
|
||||
return f"llama_toolchain.{self.api.value}.client"
|
||||
|
||||
|
||||
def remote_provider_spec(api: Api) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(api=api)
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
if self.adapter:
|
||||
return self.adapter.pip_packages
|
||||
return []
|
||||
|
||||
|
||||
# TODO: use computed_field to avoid this wrapper
|
||||
# the @computed_field decorator
|
||||
def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
|
||||
# Can avoid this by using Pydantic computed_field
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: Optional[AdapterSpec] = None
|
||||
) -> RemoteProviderSpec:
|
||||
config_class = (
|
||||
adapter.config_class
|
||||
if adapter.config_class
|
||||
if adapter and adapter.config_class
|
||||
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||
)
|
||||
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
|
||||
|
||||
return InlineProviderSpec(
|
||||
api=api,
|
||||
provider_id=remote_provider_id(adapter.adapter_id),
|
||||
pip_packages=adapter.pip_packages,
|
||||
module=adapter.module,
|
||||
config_class=config_class,
|
||||
is_adapter=True,
|
||||
return RemoteProviderSpec(
|
||||
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ from .datatypes import (
|
|||
DistributionSpec,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
# These are the dependencies needed by the distribution server.
|
||||
|
@ -89,9 +90,12 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
|||
a.provider_id: a for a in available_agentic_system_providers()
|
||||
}
|
||||
|
||||
return {
|
||||
ret = {
|
||||
Api.inference: inference_providers_by_id,
|
||||
Api.safety: safety_providers_by_id,
|
||||
Api.agentic_system: agentic_system_providers_by_id,
|
||||
Api.memory: {a.provider_id: a for a in available_memory_providers()},
|
||||
}
|
||||
for k, v in ret.items():
|
||||
v["remote"] = remote_provider_spec(k)
|
||||
return ret
|
||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
import importlib
|
||||
from typing import Any, Dict
|
||||
|
||||
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
|
||||
from .datatypes import ProviderSpec, RemoteProviderConfig, RemoteProviderSpec
|
||||
|
||||
|
||||
def instantiate_class_type(fully_qualified_name):
|
||||
|
@ -26,16 +26,21 @@ def instantiate_provider(
|
|||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
if isinstance(provider_spec, InlineProviderSpec):
|
||||
if provider_spec.is_adapter:
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
if provider_spec.adapter:
|
||||
if not issubclass(config_type, RemoteProviderConfig):
|
||||
raise ValueError(
|
||||
f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
|
||||
)
|
||||
config = config_type(**provider_config)
|
||||
|
||||
if isinstance(provider_spec, InlineProviderSpec):
|
||||
args = [config, deps]
|
||||
method = "get_adapter_impl"
|
||||
else:
|
||||
method = "get_client_impl"
|
||||
else:
|
||||
args = [config]
|
||||
return asyncio.run(module.get_provider_impl(*args))
|
||||
method = "get_provider_impl"
|
||||
|
||||
config = config_type(**provider_config)
|
||||
fn = getattr(module, method)
|
||||
impl = asyncio.run(fn(config, deps))
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
return impl
|
||||
|
|
|
@ -38,7 +38,7 @@ from pydantic import BaseModel, ValidationError
|
|||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .datatypes import Api, ProviderSpec, RemoteProviderSpec
|
||||
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints, api_providers
|
||||
from .dynamic import instantiate_provider
|
||||
|
||||
|
@ -230,10 +230,9 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
|||
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||
visited.add(a.api)
|
||||
|
||||
if not isinstance(a, RemoteProviderSpec):
|
||||
for api in a.api_dependencies:
|
||||
if api not in visited:
|
||||
dfs(by_id[api], visited, stack)
|
||||
for api in a.api_dependencies:
|
||||
if api not in visited:
|
||||
dfs(by_id[api], visited, stack)
|
||||
|
||||
stack.append(a.api)
|
||||
|
||||
|
@ -261,7 +260,10 @@ def resolve_impls(
|
|||
f"Could not find provider_spec config for {api}. Please add it to the config"
|
||||
)
|
||||
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
if isinstance(provider_spec, InlineProviderSpec):
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
else:
|
||||
deps = {}
|
||||
provider_config = provider_configs[api.value]
|
||||
impl = instantiate_provider(provider_spec, provider_config, deps)
|
||||
impls[api] = impl
|
||||
|
@ -302,7 +304,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
and provider_spec.adapter is None
|
||||
):
|
||||
for endpoint in endpoints:
|
||||
url = impl.base_url + endpoint.route
|
||||
url = impl.__provider_config__.url
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
|
|
|
@ -4,4 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .ollama import get_provider_impl # noqa
|
||||
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
impl = OllamaInferenceAdapter(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -14,34 +14,18 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
|||
from llama_models.sku_list import resolve_model
|
||||
from ollama import AsyncClient
|
||||
|
||||
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
|
||||
from llama_toolchain.inference.api import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionRequest,
|
||||
Inference,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_toolchain.inference.api import * # noqa: F403
|
||||
from llama_toolchain.inference.prepare_messages import prepare_messages
|
||||
|
||||
# TODO: Eventually this will move to the llama cli model list command
|
||||
# mapping of Model SKUs to ollama models
|
||||
OLLAMA_SUPPORTED_SKUS = {
|
||||
# "Meta-Llama3.1-8B-Instruct": "llama3.1",
|
||||
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
|
||||
"Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
||||
}
|
||||
|
||||
|
||||
async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
|
||||
impl = OllamaInferenceAdapter(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.url = url
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
@ -26,7 +26,7 @@ from .api import (
|
|||
from .event_logger import EventLogger
|
||||
|
||||
|
||||
async def get_provider_impl(config: RemoteProviderConfig) -> Inference:
|
||||
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
|
||||
return InferenceClient(config.url)
|
||||
|
||||
|
||||
|
|
|
@ -5,4 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import MetaReferenceImplConfig # noqa
|
||||
from .inference import get_provider_impl # noqa
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
assert isinstance(
|
||||
config, MetaReferenceImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -6,12 +6,11 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncIterator, Dict, Union
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_toolchain.inference.api import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
@ -27,18 +26,6 @@ from .config import MetaReferenceImplConfig
|
|||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
|
||||
):
|
||||
assert isinstance(
|
||||
config, MetaReferenceImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
# there's a single model parallel process running serving the model. for now,
|
||||
# we don't support multiple concurrent requests to this process.
|
||||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
|
|
@ -27,7 +27,7 @@ def available_inference_providers() -> List[ProviderSpec]:
|
|||
module="llama_toolchain.inference.meta_reference",
|
||||
config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
|
||||
),
|
||||
adapter_provider_spec(
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_id="ollama",
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
@ -16,7 +16,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
|
|||
from .api import * # noqa: F403
|
||||
|
||||
|
||||
async def get_provider_impl(config: RemoteProviderConfig) -> Memory:
|
||||
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
|
||||
return MemoryClient(config.url)
|
||||
|
||||
|
||||
|
|
|
@ -4,5 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import FaissImplConfig # noqa
|
||||
from .faiss import get_provider_impl # noqa
|
||||
from .config import FaissImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: FaissImplConfig, _deps):
|
||||
from .faiss import FaissMemoryImpl
|
||||
|
||||
assert isinstance(
|
||||
config, FaissImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = FaissMemoryImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -15,21 +15,10 @@ import numpy as np
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_toolchain.memory.api import * # noqa: F403
|
||||
from .config import FaissImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSpec]):
|
||||
assert isinstance(
|
||||
config, FaissImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = FaissMemoryImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
async def content_from_doc(doc: MemoryBankDocument) -> str:
|
||||
if isinstance(doc.content, URL):
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
|
|
@ -6,11 +6,12 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from typing import Any
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
||||
from llama_models.llama3.api.datatypes import UserMessage
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
|
@ -19,7 +20,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
|
|||
from .api import * # noqa: F403
|
||||
|
||||
|
||||
async def get_provider_impl(config: RemoteProviderConfig) -> Safety:
|
||||
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
|
||||
return SafetyClient(config.url)
|
||||
|
||||
|
||||
|
|
|
@ -4,5 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import SafetyConfig # noqa
|
||||
from .safety import get_provider_impl # noqa
|
||||
from .config import SafetyConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SafetyConfig, _deps):
|
||||
from .safety import MetaReferenceSafetyImpl
|
||||
|
||||
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -5,12 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_toolchain.safety.api import * # noqa
|
||||
|
||||
from .config import SafetyConfig
|
||||
|
@ -25,14 +23,6 @@ from .shields import (
|
|||
)
|
||||
|
||||
|
||||
async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
|
||||
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
def resolve_and_get_path(model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert model is not None, f"Could not resolve model {model_name}"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue