mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

Several smaller fixes to make adapters work

Also, reorganized the pattern of __init__ inside providers so configuration can stay lightweight.

parent 2a1552a5eb
commit 45987996c4

23 changed files with 164 additions and 160 deletions
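The recurring pattern in the provider __init__.py hunks below: each provider package now exposes an async factory that imports its heavy implementation module lazily, so importing the package just to read or validate configuration stays cheap. A minimal sketch of the pattern — the names ExampleImplConfig, ExampleImpl, and .example are placeholders, not from this diff:

    # Sketch of the provider __init__.py pattern this commit introduces.
    # ExampleImplConfig / ExampleImpl / .example are placeholder names.
    from .config import ExampleImplConfig  # cheap import: config only


    async def get_provider_impl(config: ExampleImplConfig, _deps):
        # Heavy dependencies load only when the provider is actually
        # instantiated, not when the package is imported for configuration.
        from .example import ExampleImpl

        assert isinstance(
            config, ExampleImplConfig
        ), f"Unexpected config type: {type(config)}"

        impl = ExampleImpl(config)
        await impl.initialize()
        return impl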
@@ -194,11 +194,6 @@ async def run_rag(host: str, port: int):
         MemoryToolDefinition(
             max_tokens_in_context=2048,
             memory_bank_configs=[],
-            # memory_bank_configs=[
-            #     AgenticSystemVectorMemoryBankConfig(
-            #         bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
-            #     )
-            # ],
         ),
     ]

@@ -210,8 +205,9 @@ async def run_rag(host: str, port: int):
     await _run_agent(api, tool_definitions, user_prompts, attachments)


-def main(host: str, port: int):
-    asyncio.run(run_rag(host, port))
+def main(host: str, port: int, rag: bool = False):
+    fn = run_rag if rag else run_main
+    asyncio.run(fn(host, port))


 if __name__ == "__main__":
@@ -4,5 +4,27 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .agentic_system import get_provider_impl  # noqa
-from .config import MetaReferenceImplConfig  # noqa
+from typing import Dict
+
+from llama_toolchain.distribution.datatypes import Api, ProviderSpec
+
+from .config import MetaReferenceImplConfig
+
+
+async def get_provider_impl(
+    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
+):
+    from .agentic_system import MetaReferenceAgenticSystemImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceAgenticSystemImpl(
+        config,
+        deps[Api.inference],
+        deps[Api.memory],
+        deps[Api.safety],
+    )
+    await impl.initialize()
+    return impl
@@ -8,9 +8,8 @@
 import logging
 import os
 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import Inference
 from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
@@ -31,23 +30,6 @@ logger = logging.getLogger()
 logger.setLevel(logging.INFO)


-async def get_provider_impl(
-    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
-):
-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceAgenticSystemImpl(
-        config,
-        deps[Api.inference],
-        deps[Api.memory],
-        deps[Api.safety],
-    )
-    await impl.initialize()
-    return impl
-
-
 AGENT_INSTANCES_BY_ID = {}

@@ -10,7 +10,7 @@ import os
 from pydantic import BaseModel
 from datetime import datetime
 from enum import Enum
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

 import pkg_resources
 import yaml
@@ -37,26 +37,17 @@ def get_dependencies(
 ) -> Dependencies:
     from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES

-    def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
-        if isinstance(provider, InlineProviderSpec):
-            return provider.pip_packages, provider.docker_image
-        else:
-            if provider.adapter:
-                return provider.adapter.pip_packages, None
-            return [], None
-
-    pip_packages, docker_image = _deps(provider)
+    pip_packages = provider.pip_packages
     for dep in dependencies.values():
-        dep_pip_packages, dep_docker_image = _deps(dep)
-        if docker_image and dep_docker_image:
+        if dep.docker_image:
             raise ValueError(
                 "You can only have the root provider specify a docker image"
             )
-        pip_packages.extend(dep_pip_packages)
+        pip_packages.extend(dep.pip_packages)

     return Dependencies(
-        docker_image=docker_image, pip_packages=pip_packages + SERVER_DEPENDENCIES
+        docker_image=provider.docker_image,
+        pip_packages=pip_packages + SERVER_DEPENDENCIES
     )

@@ -158,6 +149,7 @@ class ApiBuild(Subcommand):
         build_dir = BUILDS_BASE_DIR / args.api
         os.makedirs(build_dir, exist_ok=True)

+        # get these names straight. too confusing.
         provider_deps = parse_dependencies(args.dependencies or "", self.parser)
         dependencies = get_dependencies(providers[args.provider], provider_deps)

@@ -167,7 +159,7 @@ class ApiBuild(Subcommand):
             api.value: {
                 "provider_id": args.provider,
             },
-            **{k: {"provider_id": v} for k, v in provider_deps.items()},
+            **provider_deps,
         }
         with open(package_file, "w") as f:
             c = PackageConfig(
@@ -48,7 +48,10 @@ class ApiConfigure(Subcommand):
         )

     def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
-        config_file = BUILDS_BASE_DIR / args.api / f"{args.name}.yaml"
+        name = args.name
+        if not name.endswith(".yaml"):
+            name += ".yaml"
+        config_file = BUILDS_BASE_DIR / args.api / name
         if not config_file.exists():
             self.parser.error(
                 f"Could not find {config_file}. Please run `llama api build` first"
@@ -79,10 +82,19 @@ def configure_llama_provider(config_file: Path) -> None:
         )

     provider_spec = providers[provider_id]
-    cprint(f"Configuring API surface: {api}", "white", attrs=["bold"])
+    cprint(
+        f"Configuring API surface: {api} ({provider_id})", "white", attrs=["bold"]
+    )
     config_type = instantiate_class_type(provider_spec.config_class)

+    try:
+        existing_provider_config = config_type(**stub_config)
+    except KeyError:
+        existing_provider_config = None
+
     provider_config = prompt_for_config(
         config_type,
+        existing_provider_config,
     )
     print("")

@@ -29,10 +29,9 @@ class ApiStart(Subcommand):

     def _add_arguments(self):
         self.parser.add_argument(
-            "--yaml-config",
+            "yaml_config",
             type=str,
             help="Yaml config containing the API build configuration",
-            required=True,
         )
         self.parser.add_argument(
             "--port",
@@ -69,7 +69,7 @@ ensure_conda_env_python310() {
     conda create -n "${env_name}" python="${python_version}" -y

     ENVNAME="${env_name}"
-    setup_cleanup_handlers
+    # setup_cleanup_handlers
   fi

   eval "$(conda shell.bash hook)"
@@ -36,16 +36,14 @@ class ProviderSpec(BaseModel):
         ...,
         description="Fully-qualified classname of the config for this provider",
     )
+    api_dependencies: List[Api] = Field(
+        default_factory=list,
+        description="Higher-level API surfaces may depend on other providers to provide their functionality",
+    )


 @json_schema_type
 class AdapterSpec(BaseModel):
-    """
-    If some code is needed to convert the remote responses into Llama Stack compatible
-    API responses, specify the adapter here. If not specified, it indicates the remote
-    as being "Llama Stack compatible"
-    """
-
     adapter_id: str = Field(
         ...,
         description="Unique identifier for this adapter",
@@ -89,11 +87,6 @@ Fully-qualified name of the module to import. The module is expected to have:
 - `get_provider_impl(config, deps)`: returns the local implementation
 """,
     )
-    api_dependencies: List[Api] = Field(
-        default_factory=list,
-        description="Higher-level API surfaces may depend on other providers to provide their functionality",
-    )
-    is_adapter: bool = False


 class RemoteProviderConfig(BaseModel):
@@ -113,34 +106,41 @@ def remote_provider_id(adapter_id: str) -> str:

 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    provider_id: str = "remote"
-    config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
+    adapter: Optional[AdapterSpec] = Field(
+        default=None,
+        description="""
+If some code is needed to convert the remote responses into Llama Stack compatible
+API responses, specify the adapter here. If not specified, it indicates the remote
+as being "Llama Stack compatible"
+""",
+    )

     @property
     def module(self) -> str:
+        if self.adapter:
+            return self.adapter.module
         return f"llama_toolchain.{self.api.value}.client"

-
-def remote_provider_spec(api: Api) -> RemoteProviderSpec:
-    return RemoteProviderSpec(api=api)
+    @property
+    def pip_packages(self) -> List[str]:
+        if self.adapter:
+            return self.adapter.pip_packages
+        return []


-# TODO: use computed_field to avoid this wrapper
-# the @computed_field decorator
-def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
+# Can avoid this by using Pydantic computed_field
+def remote_provider_spec(
+    api: Api, adapter: Optional[AdapterSpec] = None
+) -> RemoteProviderSpec:
     config_class = (
         adapter.config_class
-        if adapter.config_class
+        if adapter and adapter.config_class
         else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
     )
+    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"

-    return InlineProviderSpec(
-        api=api,
-        provider_id=remote_provider_id(adapter.adapter_id),
-        pip_packages=adapter.pip_packages,
-        module=adapter.module,
-        config_class=config_class,
-        is_adapter=True,
-    )
+    return RemoteProviderSpec(
+        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
+    )

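After this change a single remote_provider_spec() call covers both the plain "Llama Stack compatible" remote case and the adapter-backed case. A sketch of both uses — the pip_packages and module values below are illustrative, not taken from this diff:

    from llama_toolchain.distribution.datatypes import (
        AdapterSpec,
        Api,
        remote_provider_spec,
    )

    # Plain remote: the endpoint already speaks the Llama Stack API, so the
    # generic per-API client module is used and no adapter is attached.
    plain = remote_provider_spec(api=Api.inference)
    assert plain.provider_id == "remote" and plain.pip_packages == []

    # Adapter-backed remote: the adapter supplies the module and pip packages;
    # provider_id is derived via remote_provider_id(adapter.adapter_id).
    ollama = remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_id="ollama",
            pip_packages=["ollama"],  # illustrative value
            module="llama_toolchain.inference.adapters.ollama",  # illustrative value
        ),
    )
    assert ollama.pip_packages == ["ollama"]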
@@ -22,6 +22,7 @@ from .datatypes import (
     DistributionSpec,
     InlineProviderSpec,
     ProviderSpec,
+    remote_provider_spec,
 )

 # These are the dependencies needed by the distribution server.
@@ -89,9 +90,12 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
         a.provider_id: a for a in available_agentic_system_providers()
     }

-    return {
+    ret = {
         Api.inference: inference_providers_by_id,
         Api.safety: safety_providers_by_id,
         Api.agentic_system: agentic_system_providers_by_id,
         Api.memory: {a.provider_id: a for a in available_memory_providers()},
     }
+    for k, v in ret.items():
+        v["remote"] = remote_provider_spec(k)
+    return ret
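The effect of the loop above, sketched under the assumption that the provider tables are built as in this function: every API surface now resolves a plain pass-through remote provider under the id "remote", with no adapter attached.

    from llama_toolchain.distribution.datatypes import RemoteProviderSpec
    from llama_toolchain.distribution.distribution import api_providers

    providers = api_providers()
    for api, by_id in providers.items():
        spec = by_id["remote"]
        # Auto-registered: no adapter, so the generic client module serves it.
        assert isinstance(spec, RemoteProviderSpec) and spec.adapter is None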
@@ -8,7 +8,7 @@ import asyncio
 import importlib
 from typing import Any, Dict

-from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
+from .datatypes import ProviderSpec, RemoteProviderConfig, RemoteProviderSpec


 def instantiate_class_type(fully_qualified_name):
@@ -26,16 +26,21 @@ def instantiate_provider(
     module = importlib.import_module(provider_spec.module)

     config_type = instantiate_class_type(provider_spec.config_class)
-    if isinstance(provider_spec, InlineProviderSpec):
-        if provider_spec.is_adapter:
+    if isinstance(provider_spec, RemoteProviderSpec):
+        if provider_spec.adapter:
             if not issubclass(config_type, RemoteProviderConfig):
                 raise ValueError(
                     f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
                 )
-    config = config_type(**provider_config)
-
-    if isinstance(provider_spec, InlineProviderSpec):
-        args = [config, deps]
+            method = "get_adapter_impl"
+        else:
+            method = "get_client_impl"
     else:
-        args = [config]
-    return asyncio.run(module.get_provider_impl(*args))
+        method = "get_provider_impl"
+
+    config = config_type(**provider_config)
+    fn = getattr(module, method)
+    impl = asyncio.run(fn(config, deps))
+    impl.__provider_spec__ = provider_spec
+    impl.__provider_config__ = config
+    return impl
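The dispatch above fixes a naming convention for provider modules. As a rough summary — the signatures follow the diff, but the bodies here are only placeholders:

    # The three async entry points instantiate_provider() now dispatches to.

    # Inline provider (a provider package __init__.py, e.g. meta_reference):
    async def get_provider_impl(config, deps):
        ...  # lazily import the implementation, initialize it, return it

    # Remote provider with an adapter (e.g. the ollama adapter package):
    async def get_adapter_impl(config, _deps):
        ...  # wrap the remote service, translating responses as needed

    # Plain remote provider (the generic per-API client module):
    async def get_client_impl(config, _deps):
        ...  # return a pass-through client for a Llama Stack compatible endpoint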
@@ -38,7 +38,7 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated

-from .datatypes import Api, ProviderSpec, RemoteProviderSpec
+from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 from .distribution import api_endpoints, api_providers
 from .dynamic import instantiate_provider

@@ -230,10 +230,9 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
     def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
         visited.add(a.api)

-        if not isinstance(a, RemoteProviderSpec):
-            for api in a.api_dependencies:
-                if api not in visited:
-                    dfs(by_id[api], visited, stack)
+        for api in a.api_dependencies:
+            if api not in visited:
+                dfs(by_id[api], visited, stack)

         stack.append(a.api)

@@ -261,7 +260,10 @@ def resolve_impls(
                 f"Could not find provider_spec config for {api}. Please add it to the config"
             )

-        deps = {api: impls[api] for api in provider_spec.api_dependencies}
+        if isinstance(provider_spec, InlineProviderSpec):
+            deps = {api: impls[api] for api in provider_spec.api_dependencies}
+        else:
+            deps = {}
         provider_config = provider_configs[api.value]
         impl = instantiate_provider(provider_spec, provider_config, deps)
         impls[api] = impl
@@ -302,7 +304,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
             and provider_spec.adapter is None
         ):
             for endpoint in endpoints:
-                url = impl.base_url + endpoint.route
+                url = impl.__provider_config__.url
                 getattr(app, endpoint.method)(endpoint.route)(
                     create_dynamic_passthrough(url)
                 )
@@ -4,4 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .ollama import get_provider_impl  # noqa
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
+
+
+async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+    from .ollama import OllamaInferenceAdapter
+
+    impl = OllamaInferenceAdapter(config.url)
+    await impl.initialize()
+    return impl
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, AsyncGenerator
+from typing import AsyncGenerator

 import httpx

@@ -14,34 +14,18 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 from ollama import AsyncClient

-from llama_toolchain.distribution.datatypes import RemoteProviderConfig
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.inference.prepare_messages import prepare_messages

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
+    # "Meta-Llama3.1-8B-Instruct": "llama3.1",
     "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
 }


-async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
-    impl = OllamaInferenceAdapter(config.url)
-    await impl.initialize()
-    return impl
-
-
 class OllamaInferenceAdapter(Inference):
     def __init__(self, url: str) -> None:
         self.url = url
@@ -6,7 +6,7 @@

 import asyncio
 import json
-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator

 import fire
 import httpx
@@ -26,7 +26,7 @@ from .api import (
 from .event_logger import EventLogger


-async def get_provider_impl(config: RemoteProviderConfig) -> Inference:
+async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
     return InferenceClient(config.url)

@@ -5,4 +5,15 @@
 # the root directory of this source tree.

 from .config import MetaReferenceImplConfig  # noqa
-from .inference import get_provider_impl  # noqa
+
+
+async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
+    from .inference import MetaReferenceInferenceImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceInferenceImpl(config)
+    await impl.initialize()
+    return impl
@@ -6,12 +6,11 @@

 import asyncio

-from typing import AsyncIterator, Dict, Union
+from typing import AsyncIterator, Union

 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -27,18 +26,6 @@ from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator


-async def get_provider_impl(
-    config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
-):
-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceInferenceImpl(config)
-    await impl.initialize()
-    return impl
-
-
 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.
 SEMAPHORE = asyncio.Semaphore(1)
@@ -27,7 +27,7 @@ def available_inference_providers() -> List[ProviderSpec]:
         module="llama_toolchain.inference.meta_reference",
         config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
     ),
-    adapter_provider_spec(
+    remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_id="ollama",
@@ -6,7 +6,7 @@

 import asyncio

-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 import fire
 import httpx
@@ -16,7 +16,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 from .api import *  # noqa: F403


-async def get_provider_impl(config: RemoteProviderConfig) -> Memory:
+async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
     return MemoryClient(config.url)

@@ -4,5 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import FaissImplConfig  # noqa
-from .faiss import get_provider_impl  # noqa
+from .config import FaissImplConfig
+
+
+async def get_provider_impl(config: FaissImplConfig, _deps):
+    from .faiss import FaissMemoryImpl
+
+    assert isinstance(
+        config, FaissImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = FaissMemoryImpl(config)
+    await impl.initialize()
+    return impl
@@ -15,21 +15,10 @@ import numpy as np
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.memory.api import *  # noqa: F403
 from .config import FaissImplConfig


-async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSpec]):
-    assert isinstance(
-        config, FaissImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
-    impl = FaissMemoryImpl(config)
-    await impl.initialize()
-    return impl
-
-
 async def content_from_doc(doc: MemoryBankDocument) -> str:
     if isinstance(doc.content, URL):
         async with httpx.AsyncClient() as client:
@@ -6,11 +6,12 @@

 import asyncio

+from typing import Any

 import fire
 import httpx

 from llama_models.llama3.api.datatypes import UserMessage

 from pydantic import BaseModel
 from termcolor import cprint
@@ -19,7 +20,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 from .api import *  # noqa: F403


-async def get_provider_impl(config: RemoteProviderConfig) -> Safety:
+async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
     return SafetyClient(config.url)

@@ -4,5 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import SafetyConfig  # noqa
-from .safety import get_provider_impl  # noqa
+from .config import SafetyConfig
+
+
+async def get_provider_impl(config: SafetyConfig, _deps):
+    from .safety import MetaReferenceSafetyImpl
+
+    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceSafetyImpl(config)
+    await impl.initialize()
+    return impl
@@ -5,12 +5,10 @@
 # the root directory of this source tree.

 import asyncio
-from typing import Dict

 from llama_models.sku_list import resolve_model

 from llama_toolchain.common.model_utils import model_local_dir
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.safety.api import *  # noqa

 from .config import SafetyConfig
@@ -25,14 +23,6 @@ from .shields import (
 )


-async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
-    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceSafetyImpl(config)
-    await impl.initialize()
-    return impl
-
-
 def resolve_and_get_path(model_name: str) -> str:
     model = resolve_model(model_name)
     assert model is not None, f"Could not resolve model {model_name}"