initial branch commit

This commit is contained in:
dltn 2024-11-05 14:46:11 -08:00
parent 37b330b4ef
commit cb1c093faa
3 changed files with 110 additions and 0 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import yaml
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
class LlamaStackInline:
def __init__(self, run_config_path: str):
self.run_config_path = run_config_path
self.impls = {}
self.run_config = None
def print_pip_command(self):
# TODO: de-dupe this with build.py
all_providers = get_provider_registry()
deps = []
for (
api_str,
provider_or_providers,
) in self.run_config.providers.items():
providers_for_api = all_providers[Api(api_str)]
providers = (
provider_or_providers
if isinstance(provider_or_providers, list)
else [provider_or_providers]
)
for provider in providers:
if provider.provider_id not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)
provider_spec = providers_for_api[provider.provider_id]
deps.extend(provider_spec.pip_packages)
if provider_spec.docker_image:
raise ValueError(
"A stack's dependencies cannot have a docker image"
)
normal_deps = []
special_deps = []
for package in deps:
if "--no-deps" in package or "--index-url" in package:
special_deps.append(package)
else:
normal_deps.append(package)
deps = list(set(deps))
special_deps = list(set(special_deps))
print(
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
)
for special_dep in special_deps:
print(f"\tpip install {special_dep}")
print()
async def initialize(self):
with open(self.run_config_path, "r") as f:
config_dict = yaml.safe_load(f)
self.run_config = parse_and_maybe_upgrade_config(config_dict)
all_providers = get_provider_registry()
try:
impls = await resolve_impls(self.run_config, all_providers)
self.impls = impls
except ModuleNotFoundError as e:
print(str(e))
self.print_pip_command()
if "provider_data" in config_dict:
provider_id = chosen[api.value][0].provider_id
provider_data = config_dict["provider_data"].get(provider_id, {})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
)

View file

@ -0,0 +1,16 @@
from inline import LlamaStackInline
from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.datatypes import * # noqa: F403
async def main():
inline = LlamaStackInline("/home/dalton/.llama/builds/conda/nov5-run.yaml")
await inline.initialize()
print(inline.impls)
# Run the main function
import asyncio
asyncio.run(main())