Allow setting environment variables from llama stack run and fix ollama

This commit is contained in:
Ashwin Bharambe 2024-11-17 19:33:48 -08:00
parent a061f3f8c1
commit b1d119466e
19 changed files with 129 additions and 55 deletions

View file

@ -217,15 +217,23 @@ class StackBuild(Subcommand):
provider_types = [provider_types]
for i, provider_type in enumerate(provider_types):
p_spec = Provider(
provider_id=f"{provider_type}-{i}",
provider_type=provider_type,
config={},
)
pid = provider_type.split("::")[-1]
config_type = instantiate_class_type(
provider_registry[Api(api)][provider_type].config_class
)
p_spec.config = config_type()
if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(
__distro_dir__=f"distributions/{build_config.name}"
)
else:
config = {}
p_spec = Provider(
provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid,
provider_type=provider_type,
config=config,
)
run_config.providers[api].append(p_spec)
os.makedirs(build_dir, exist_ok=True)

View file

@ -39,6 +39,13 @@ class StackRun(Subcommand):
help="Disable IPv6 support",
default=False,
)
self.parser.add_argument(
"--env",
action="append",
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
default=[],
metavar="KEY=VALUE",
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
from pathlib import Path
@ -97,4 +104,16 @@ class StackRun(Subcommand):
if args.disable_ipv6:
run_args.append("--disable-ipv6")
for env_var in args.env:
if "=" not in env_var:
self.parser.error(
f"Environment variable '{env_var}' must be in KEY=VALUE format"
)
return
key, value = env_var.split("=", 1) # split on first = only
if not key:
self.parser.error(f"Environment variable '{env_var}' has empty key")
return
run_args.extend(["--env", f"{key}={value}"])
run_with_pty(run_args)