Several smaller fixes to make adapters work

Also, reorganized the pattern of __init__ inside providers so
configuration can stay lightweight
This commit is contained in:
Ashwin Bharambe 2024-08-28 09:42:08 -07:00
parent 2a1552a5eb
commit 45987996c4
23 changed files with 164 additions and 160 deletions
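The reorganized __init__.py pattern is the heart of this change: each provider package now exposes an async factory at the package level and defers importing the heavy implementation module until that factory actually runs, so loading a package just to read or validate its config stays cheap. A minimal sketch of the shape, using names taken from the safety provider diff below (the comment about heavy dependencies is an assumption about the motivation):

    # providers/.../__init__.py -- the pattern this commit rolls out
    from .config import SafetyConfig  # lightweight: just the config class

    async def get_provider_impl(config: SafetyConfig, _deps):
        # Deferred import: the implementation and its heavy dependencies are
        # only pulled in when the provider is instantiated, not at config time.
        from .safety import MetaReferenceSafetyImpl

        assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"

        impl = MetaReferenceSafetyImpl(config)
        await impl.initialize()
        return impl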

View file

@@ -194,11 +194,6 @@ async def run_rag(host: str, port: int):
         MemoryToolDefinition(
             max_tokens_in_context=2048,
             memory_bank_configs=[],
-            # memory_bank_configs=[
-            #     AgenticSystemVectorMemoryBankConfig(
-            #         bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
-            #     )
-            # ],
         ),
     ]
@@ -210,8 +205,9 @@ async def run_rag(host: str, port: int):
     await _run_agent(api, tool_definitions, user_prompts, attachments)

-def main(host: str, port: int):
-    asyncio.run(run_rag(host, port))
+def main(host: str, port: int, rag: bool = False):
+    fn = run_rag if rag else run_main
+    asyncio.run(fn(host, port))

 if __name__ == "__main__":

View file

@@ -4,5 +4,27 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .agentic_system import get_provider_impl  # noqa
-from .config import MetaReferenceImplConfig  # noqa
+from typing import Dict
+
+from llama_toolchain.distribution.datatypes import Api, ProviderSpec
+
+from .config import MetaReferenceImplConfig
+
+
+async def get_provider_impl(
+    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
+):
+    from .agentic_system import MetaReferenceAgenticSystemImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceAgenticSystemImpl(
+        config,
+        deps[Api.inference],
+        deps[Api.memory],
+        deps[Api.safety],
+    )
+    await impl.initialize()
+    return impl

View file

@@ -8,9 +8,8 @@
 import logging
 import os
 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import Inference
 from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
@@ -31,23 +30,6 @@ logger = logging.getLogger()
 logger.setLevel(logging.INFO)

-async def get_provider_impl(
-    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
-):
-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceAgenticSystemImpl(
-        config,
-        deps[Api.inference],
-        deps[Api.memory],
-        deps[Api.safety],
-    )
-    await impl.initialize()
-    return impl
-
 AGENT_INSTANCES_BY_ID = {}

View file

@@ -10,7 +10,7 @@ import os
 from pydantic import BaseModel
 from datetime import datetime
 from enum import Enum
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

 import pkg_resources
 import yaml
@@ -37,26 +37,17 @@ def get_dependencies(
 ) -> Dependencies:
     from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES

-    def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
-        if isinstance(provider, InlineProviderSpec):
-            return provider.pip_packages, provider.docker_image
-        else:
-            if provider.adapter:
-                return provider.adapter.pip_packages, None
-            return [], None
-
-    pip_packages, docker_image = _deps(provider)
+    pip_packages = provider.pip_packages
     for dep in dependencies.values():
-        dep_pip_packages, dep_docker_image = _deps(dep)
-        if docker_image and dep_docker_image:
+        if dep.docker_image:
             raise ValueError(
                 "You can only have the root provider specify a docker image"
             )
-        pip_packages.extend(dep_pip_packages)
+        pip_packages.extend(dep.pip_packages)

     return Dependencies(
-        docker_image=docker_image, pip_packages=pip_packages + SERVER_DEPENDENCIES
+        docker_image=provider.docker_image,
+        pip_packages=pip_packages + SERVER_DEPENDENCIES
     )
@@ -158,6 +149,7 @@ class ApiBuild(Subcommand):
         build_dir = BUILDS_BASE_DIR / args.api
         os.makedirs(build_dir, exist_ok=True)

+        # get these names straight. too confusing.
         provider_deps = parse_dependencies(args.dependencies or "", self.parser)
         dependencies = get_dependencies(providers[args.provider], provider_deps)
@@ -167,7 +159,7 @@ class ApiBuild(Subcommand):
             api.value: {
                 "provider_id": args.provider,
             },
-            **{k: {"provider_id": v} for k, v in provider_deps.items()},
+            **provider_deps,
         }

         with open(package_file, "w") as f:
             c = PackageConfig(

View file

@@ -48,7 +48,10 @@ class ApiConfigure(Subcommand):
         )

     def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
-        config_file = BUILDS_BASE_DIR / args.api / f"{args.name}.yaml"
+        name = args.name
+        if not name.endswith(".yaml"):
+            name += ".yaml"
+        config_file = BUILDS_BASE_DIR / args.api / name
         if not config_file.exists():
             self.parser.error(
                 f"Could not find {config_file}. Please run `llama api build` first"
@@ -79,10 +82,19 @@ def configure_llama_provider(config_file: Path) -> None:
         )

     provider_spec = providers[provider_id]
-    cprint(f"Configuring API surface: {api}", "white", attrs=["bold"])
+    cprint(
+        f"Configuring API surface: {api} ({provider_id})", "white", attrs=["bold"]
+    )

     config_type = instantiate_class_type(provider_spec.config_class)
+    try:
+        existing_provider_config = config_type(**stub_config)
+    except KeyError:
+        existing_provider_config = None
+
     provider_config = prompt_for_config(
         config_type,
+        existing_provider_config,
     )

     print("")

View file

@@ -29,10 +29,9 @@ class ApiStart(Subcommand):
     def _add_arguments(self):
         self.parser.add_argument(
-            "--yaml-config",
+            "yaml_config",
             type=str,
             help="Yaml config containing the API build configuration",
-            required=True,
         )
         self.parser.add_argument(
             "--port",

View file

@@ -69,7 +69,7 @@ ensure_conda_env_python310() {
     conda create -n "${env_name}" python="${python_version}" -y

     ENVNAME="${env_name}"
-    setup_cleanup_handlers
+    # setup_cleanup_handlers
   fi

   eval "$(conda shell.bash hook)"

View file

@@ -36,16 +36,14 @@ class ProviderSpec(BaseModel):
         ...,
         description="Fully-qualified classname of the config for this provider",
     )
+    api_dependencies: List[Api] = Field(
+        default_factory=list,
+        description="Higher-level API surfaces may depend on other providers to provide their functionality",
+    )

 @json_schema_type
 class AdapterSpec(BaseModel):
-    """
-    If some code is needed to convert the remote responses into Llama Stack compatible
-    API responses, specify the adapter here. If not specified, it indicates the remote
-    as being "Llama Stack compatible"
-    """
-
     adapter_id: str = Field(
         ...,
         description="Unique identifier for this adapter",
@@ -89,11 +87,6 @@ Fully-qualified name of the module to import. The module is expected to have:
 - `get_provider_impl(config, deps)`: returns the local implementation
 """,
     )
-    api_dependencies: List[Api] = Field(
-        default_factory=list,
-        description="Higher-level API surfaces may depend on other providers to provide their functionality",
-    )
-    is_adapter: bool = False

 class RemoteProviderConfig(BaseModel):
@@ -113,34 +106,41 @@ def remote_provider_id(adapter_id: str) -> str:
 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
     provider_id: str = "remote"
     config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"

+    adapter: Optional[AdapterSpec] = Field(
+        default=None,
+        description="""
+If some code is needed to convert the remote responses into Llama Stack compatible
+API responses, specify the adapter here. If not specified, it indicates the remote
+as being "Llama Stack compatible"
+""",
+    )
+
     @property
     def module(self) -> str:
+        if self.adapter:
+            return self.adapter.module
         return f"llama_toolchain.{self.api.value}.client"

-def remote_provider_spec(api: Api) -> RemoteProviderSpec:
-    return RemoteProviderSpec(api=api)
+    @property
+    def pip_packages(self) -> List[str]:
+        if self.adapter:
+            return self.adapter.pip_packages
+        return []

-# TODO: use computed_field to avoid this wrapper
-# the @computed_field decorator
-def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
+# Can avoid this by using Pydantic computed_field
+def remote_provider_spec(
+    api: Api, adapter: Optional[AdapterSpec] = None
+) -> RemoteProviderSpec:
     config_class = (
         adapter.config_class
-        if adapter.config_class
+        if adapter and adapter.config_class
         else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
     )
+    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"

-    return InlineProviderSpec(
-        api=api,
-        provider_id=remote_provider_id(adapter.adapter_id),
-        pip_packages=adapter.pip_packages,
-        module=adapter.module,
-        config_class=config_class,
-        is_adapter=True,
+    return RemoteProviderSpec(
+        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
     )
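With the adapter folded into RemoteProviderSpec, the single remote_provider_spec helper now covers both remote flavors. A usage sketch; the AdapterSpec fields shown for ollama are illustrative only (the pip package and module path are assumptions, not taken from this diff):

    from llama_toolchain.distribution.datatypes import (
        AdapterSpec,
        Api,
        remote_provider_spec,
    )

    # A remote that is already "Llama Stack compatible": no adapter, so module
    # falls back to llama_toolchain.inference.client.
    plain = remote_provider_spec(api=Api.inference)

    # A remote that needs translation code: the adapter supplies the module and
    # pip_packages, and provider_id comes from remote_provider_id("ollama").
    ollama = remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_id="ollama",
            pip_packages=["ollama"],  # assumed
            module="llama_toolchain.inference.adapters.ollama",  # assumed path
        ),
    )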

View file

@@ -22,6 +22,7 @@ from .datatypes import (
     DistributionSpec,
     InlineProviderSpec,
     ProviderSpec,
+    remote_provider_spec,
 )

 # These are the dependencies needed by the distribution server.
@@ -89,9 +90,12 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
         a.provider_id: a for a in available_agentic_system_providers()
     }

-    return {
+    ret = {
         Api.inference: inference_providers_by_id,
         Api.safety: safety_providers_by_id,
         Api.agentic_system: agentic_system_providers_by_id,
         Api.memory: {a.provider_id: a for a in available_memory_providers()},
     }
+
+    for k, v in ret.items():
+        v["remote"] = remote_provider_spec(k)
+
+    return ret
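A consequence of the loop above: every API surface now automatically carries the stock "remote" provider alongside its registered implementations. A small sketch of the observable effect, assuming the imports resolve as in this tree:

    from llama_toolchain.distribution.distribution import api_providers

    providers = api_providers()
    # After this change, each API's provider map includes a "remote" entry.
    assert all("remote" in v for v in providers.values())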

View file

@@ -8,7 +8,7 @@ import asyncio
 import importlib
 from typing import Any, Dict

-from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
+from .datatypes import ProviderSpec, RemoteProviderConfig, RemoteProviderSpec

 def instantiate_class_type(fully_qualified_name):
@@ -26,16 +26,21 @@ def instantiate_provider(
     module = importlib.import_module(provider_spec.module)

     config_type = instantiate_class_type(provider_spec.config_class)
-    if isinstance(provider_spec, InlineProviderSpec):
-        if provider_spec.is_adapter:
+    if isinstance(provider_spec, RemoteProviderSpec):
+        if provider_spec.adapter:
             if not issubclass(config_type, RemoteProviderConfig):
                 raise ValueError(
                     f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
                 )
-
-    config = config_type(**provider_config)
-    if isinstance(provider_spec, InlineProviderSpec):
-        args = [config, deps]
+            method = "get_adapter_impl"
         else:
-        args = [config]
-    return asyncio.run(module.get_provider_impl(*args))
+            method = "get_client_impl"
+    else:
+        method = "get_provider_impl"
+
+    config = config_type(**provider_config)
+    fn = getattr(module, method)
+    impl = asyncio.run(fn(config, deps))
+    impl.__provider_spec__ = provider_spec
+    impl.__provider_config__ = config
+    return impl
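The rewritten instantiate_provider picks one of three module-level factories instead of always calling get_provider_impl. A compact restatement of the dispatch rule, mirroring the code above (sketch only):

    from llama_toolchain.distribution.datatypes import ProviderSpec, RemoteProviderSpec

    def entry_point(provider_spec: ProviderSpec) -> str:
        # RemoteProviderSpec with an adapter -> get_adapter_impl in the adapter module
        # RemoteProviderSpec without one     -> get_client_impl in the stock client
        # anything else (inline)             -> get_provider_impl in the package __init__
        if isinstance(provider_spec, RemoteProviderSpec):
            return "get_adapter_impl" if provider_spec.adapter else "get_client_impl"
        return "get_provider_impl"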

View file

@@ -38,7 +38,7 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated

-from .datatypes import Api, ProviderSpec, RemoteProviderSpec
+from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 from .distribution import api_endpoints, api_providers
 from .dynamic import instantiate_provider
@@ -230,7 +230,6 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
     def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
         visited.add(a.api)

-        if not isinstance(a, RemoteProviderSpec):
         for api in a.api_dependencies:
             if api not in visited:
                 dfs(by_id[api], visited, stack)
@@ -261,7 +260,10 @@ def resolve_impls(
                 f"Could not find provider_spec config for {api}. Please add it to the config"
             )

+        if isinstance(provider_spec, InlineProviderSpec):
             deps = {api: impls[api] for api in provider_spec.api_dependencies}
+        else:
+            deps = {}
+
         provider_config = provider_configs[api.value]
         impl = instantiate_provider(provider_spec, provider_config, deps)
         impls[api] = impl
@@ -302,7 +304,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
             and provider_spec.adapter is None
         ):
             for endpoint in endpoints:
-                url = impl.base_url + endpoint.route
+                url = impl.__provider_config__.url
                 getattr(app, endpoint.method)(endpoint.route)(
                     create_dynamic_passthrough(url)
                 )

View file

@@ -4,4 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .ollama import get_provider_impl  # noqa
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
+
+
+async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+    from .ollama import OllamaInferenceAdapter
+
+    impl = OllamaInferenceAdapter(config.url)
+    await impl.initialize()
+    return impl

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, AsyncGenerator
+from typing import AsyncGenerator

 import httpx
@@ -14,34 +14,18 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 from ollama import AsyncClient

-from llama_toolchain.distribution.datatypes import RemoteProviderConfig
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.inference.prepare_messages import prepare_messages

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
     # "Meta-Llama3.1-8B-Instruct": "llama3.1",
     "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
 }

-async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
-    impl = OllamaInferenceAdapter(config.url)
-    await impl.initialize()
-    return impl
-
 class OllamaInferenceAdapter(Inference):
     def __init__(self, url: str) -> None:
         self.url = url

View file

@@ -6,7 +6,7 @@
 import asyncio
 import json
-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator

 import fire
 import httpx
@@ -26,7 +26,7 @@ from .api import (
 from .event_logger import EventLogger

-async def get_provider_impl(config: RemoteProviderConfig) -> Inference:
+async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
     return InferenceClient(config.url)

View file

@@ -5,4 +5,15 @@
 # the root directory of this source tree.

 from .config import MetaReferenceImplConfig  # noqa
-from .inference import get_provider_impl  # noqa
+
+
+async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
+    from .inference import MetaReferenceInferenceImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = MetaReferenceInferenceImpl(config)
+    await impl.initialize()
+    return impl

View file

@@ -6,12 +6,11 @@
 import asyncio
-from typing import AsyncIterator, Dict, Union
+from typing import AsyncIterator, Union

 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -27,18 +26,6 @@ from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator

-async def get_provider_impl(
-    config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
-):
-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = MetaReferenceInferenceImpl(config)
-    await impl.initialize()
-    return impl
-
 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.
 SEMAPHORE = asyncio.Semaphore(1)

View file

@@ -27,7 +27,7 @@ def available_inference_providers() -> List[ProviderSpec]:
             module="llama_toolchain.inference.meta_reference",
             config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
         ),
-        adapter_provider_spec(
+        remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
                 adapter_id="ollama",

View file

@@ -6,7 +6,7 @@
 import asyncio

-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 import fire
 import httpx
@@ -16,7 +16,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 from .api import *  # noqa: F403

-async def get_provider_impl(config: RemoteProviderConfig) -> Memory:
+async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
     return MemoryClient(config.url)

View file

@@ -4,5 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import FaissImplConfig  # noqa
-from .faiss import get_provider_impl  # noqa
+from .config import FaissImplConfig
+
+
+async def get_provider_impl(config: FaissImplConfig, _deps):
+    from .faiss import FaissMemoryImpl
+
+    assert isinstance(
+        config, FaissImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = FaissMemoryImpl(config)
+    await impl.initialize()
+    return impl

View file

@@ -15,21 +15,10 @@ import numpy as np
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.memory.api import *  # noqa: F403

 from .config import FaissImplConfig

-async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSpec]):
-    assert isinstance(
-        config, FaissImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
-    impl = FaissMemoryImpl(config)
-    await impl.initialize()
-    return impl
-
 async def content_from_doc(doc: MemoryBankDocument) -> str:
     if isinstance(doc.content, URL):
         async with httpx.AsyncClient() as client:
View file

@@ -6,11 +6,12 @@
 import asyncio

+from typing import Any
+
 import fire
 import httpx

 from llama_models.llama3.api.datatypes import UserMessage
 from pydantic import BaseModel
 from termcolor import cprint
@@ -19,7 +20,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 from .api import *  # noqa: F403

-async def get_provider_impl(config: RemoteProviderConfig) -> Safety:
+async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
     return SafetyClient(config.url)

View file

@@ -4,5 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import SafetyConfig  # noqa
-from .safety import get_provider_impl  # noqa
+from .config import SafetyConfig
+
+
+async def get_provider_impl(config: SafetyConfig, _deps):
+    from .safety import MetaReferenceSafetyImpl
+
+    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceSafetyImpl(config)
+    await impl.initialize()
+    return impl

View file

@@ -5,12 +5,10 @@
 # the root directory of this source tree.

 import asyncio
-from typing import Dict

 from llama_models.sku_list import resolve_model

 from llama_toolchain.common.model_utils import model_local_dir
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.safety.api import *  # noqa

 from .config import SafetyConfig
@@ -25,14 +23,6 @@ from .shields import (
 )

-async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
-    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceSafetyImpl(config)
-    await impl.initialize()
-    return impl
-
 def resolve_and_get_path(model_name: str) -> str:
     model = resolve_model(model_name)
     assert model is not None, f"Could not resolve model {model_name}"