Several smaller fixes to make adapters work

Also, reorganized the pattern of __init__ inside providers so
configuration can stay lightweight
This commit is contained in:
Ashwin Bharambe 2024-08-28 09:42:08 -07:00
parent 2a1552a5eb
commit 45987996c4
23 changed files with 164 additions and 160 deletions

View file

@ -6,11 +6,12 @@
import asyncio
from typing import Any
import fire
import httpx
from llama_models.llama3.api.datatypes import UserMessage
from pydantic import BaseModel
from termcolor import cprint
@ -19,7 +20,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
async def get_provider_impl(config: RemoteProviderConfig) -> Safety:
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
return SafetyClient(config.url)

View file

@ -4,5 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SafetyConfig # noqa
from .safety import get_provider_impl # noqa
from .config import SafetyConfig
async def get_provider_impl(config: SafetyConfig, _deps):
from .safety import MetaReferenceSafetyImpl
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config)
await impl.initialize()
return impl

View file

@ -5,12 +5,10 @@
# the root directory of this source tree.
import asyncio
from typing import Dict
from llama_models.sku_list import resolve_model
from llama_toolchain.common.model_utils import model_local_dir
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.safety.api import * # noqa
from .config import SafetyConfig
@ -25,14 +23,6 @@ from .shields import (
)
async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config)
await impl.initialize()
return impl
def resolve_and_get_path(model_name: str) -> str:
model = resolve_model(model_name)
assert model is not None, f"Could not resolve model {model_name}"