diff --git a/llama_stack/inline/__init__.py b/llama_stack/inline/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/inline/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/inline/inline.py b/llama_stack/inline/inline.py
new file mode 100644
index 000000000..9800a6ce7
--- /dev/null
+++ b/llama_stack/inline/inline.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+
+import yaml
+
+from llama_stack.providers.datatypes import *  # noqa: F403
+from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.distribution import get_provider_registry
+from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.resolver import resolve_impls
+
+
+class LlamaStackInline:
+    def __init__(self, run_config_path: str):
+        self.run_config_path = run_config_path
+        self.impls = {}
+        self.run_config = None
+
+    def print_pip_command(self):
+        # TODO: de-dupe this with build.py
+        all_providers = get_provider_registry()
+        deps = []
+        for (
+            api_str,
+            provider_or_providers,
+        ) in self.run_config.providers.items():
+            providers_for_api = all_providers[Api(api_str)]
+
+            providers = (
+                provider_or_providers
+                if isinstance(provider_or_providers, list)
+                else [provider_or_providers]
+            )
+
+            for provider in providers:
+                if provider.provider_id not in providers_for_api:
+                    raise ValueError(
+                        f"Provider `{provider}` is not available for API `{api_str}`"
+                    )
+
+                provider_spec = providers_for_api[provider.provider_id]
+                deps.extend(provider_spec.pip_packages)
+                if provider_spec.docker_image:
+                    raise ValueError(
+                        "A stack's dependencies cannot have a docker image"
+                    )
+
+        normal_deps = []
+        special_deps = []
+        for package in deps:
+            if "--no-deps" in package or "--index-url" in package:
+                special_deps.append(package)
+            else:
+                normal_deps.append(package)
+        normal_deps = list(set(normal_deps))
+        special_deps = list(set(special_deps))
+
+        print(
+            f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
+        )
+        for special_dep in special_deps:
+            print(f"\tpip install {special_dep}")
+        print()
+
+    async def initialize(self):
+        with open(self.run_config_path, "r") as f:
+            config_dict = yaml.safe_load(f)
+
+        self.run_config = parse_and_maybe_upgrade_config(config_dict)
+
+        all_providers = get_provider_registry()
+
+        try:
+            impls = await resolve_impls(self.run_config, all_providers)
+            self.impls = impls
+        except ModuleNotFoundError as e:
+            print(str(e))
+            self.print_pip_command()
+            return
+
+        if "provider_data" in config_dict:
+            # `provider_data` in the run config is keyed by provider_id; merge
+            # all entries into a single request-scoped payload.
+            provider_data = {}
+            for data in config_dict["provider_data"].values():
+                provider_data.update(data)
+            if provider_data:
+                set_request_provider_data(
+                    {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
+                )
diff --git a/llama_stack/inline/test.py b/llama_stack/inline/test.py
new file mode 100644
index 000000000..3f6823b5c
--- /dev/null
+++ b/llama_stack/inline/test.py
@@ -0,0 +1,16 @@
+import asyncio
+
+from llama_stack.apis.inference.inference import Inference  # noqa: F401
+from llama_stack.inline.inline import LlamaStackInline
+
+from llama_stack.providers.datatypes import *  # noqa: F403
+
+
+async def main():
+    inline = LlamaStackInline("/home/dalton/.llama/builds/conda/nov5-run.yaml")
+    await inline.initialize()
+    print(inline.impls)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
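For reference, here is a rough sketch of how the resolved impls could be driven once `initialize()` completes. This is not part of the diff: the run-config path and model name are placeholders, and the exact `chat_completion` signature depends on the version of the `Inference` API in your tree.

```python
import asyncio

from llama_stack.apis.inference import UserMessage
from llama_stack.inline.inline import LlamaStackInline
from llama_stack.providers.datatypes import Api


async def chat() -> None:
    # Placeholder path; point this at a run config produced by `llama stack build`.
    inline = LlamaStackInline("/path/to/run.yaml")
    await inline.initialize()

    # resolve_impls() returns a dict keyed by Api enum member.
    inference = inline.impls[Api.inference]

    # Model name is a placeholder; use one served by your run config.
    response = await inference.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Hello!")],
        stream=False,
    )
    print(response.completion_message.content)


asyncio.run(chat())
```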