Allow setting environment variables from llama stack run and fix ollama

Ashwin Bharambe 2024-11-17 19:33:48 -08:00
parent a061f3f8c1
commit b1d119466e
19 changed files with 129 additions and 55 deletions

View file

@@ -1,20 +1,20 @@
version: '2'
built_at: 2024-11-17 15:19:07.405618
built_at: 2024-11-17 19:33:00
image_name: ollama
docker_image: llamastack/distribution-ollama:test-0.0.52rc3
docker_image: null
conda_env: null
apis:
- telemetry
- agents
- memory
- inference
- agents
- safety
- inference
- telemetry
providers:
inference:
- provider_id: ollama
provider_type: remote::ollama
config:
port: ${env.OLLAMA_PORT}
url: ${env.OLLAMA_URL:http://localhost:11434}
memory:
- provider_id: faiss
provider_type: inline::faiss
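
The regenerated Ollama run configs replace the hard-coded `port: ${env.OLLAMA_PORT}` with `url: ${env.OLLAMA_URL:http://localhost:11434}`, i.e. an environment-variable reference with an inline default after the colon. A minimal sketch of how such a placeholder could be resolved (the actual substitution happens server-side, outside this diff; the helper below is illustrative only):

```python
import os
import re

# Matches ${env.VAR} or ${env.VAR:default}; hypothetical helper for illustration only.
_ENV_PATTERN = re.compile(
    r"\$\{env\.(?P<name>[A-Za-z_][A-Za-z0-9_]*)(?::(?P<default>[^}]*))?\}"
)


def resolve_env_placeholders(value: str) -> str:
    def _sub(match: re.Match) -> str:
        name, default = match.group("name"), match.group("default")
        resolved = os.environ.get(name, default)
        if resolved is None:
            raise ValueError(f"Environment variable {name} is not set and has no default")
        return resolved

    return _ENV_PATTERN.sub(_sub, value)


# With OLLAMA_URL unset, the default after the colon is used:
print(resolve_env_placeholders("${env.OLLAMA_URL:http://localhost:11434}"))
# -> http://localhost:11434
```

Combined with the new `--env` flag on `llama stack run` introduced later in this commit, the same run.yaml can point at a different Ollama endpoint without editing the file.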

View file

@@ -1,20 +1,20 @@
version: '2'
built_at: 2024-11-17 15:19:07.395495
built_at: 2024-11-17 19:33:00
image_name: ollama
docker_image: llamastack/distribution-ollama:test-0.0.52rc3
docker_image: null
conda_env: null
apis:
- telemetry
- agents
- memory
- inference
- agents
- safety
- inference
- telemetry
providers:
inference:
- provider_id: ollama
provider_type: remote::ollama
config:
port: ${env.OLLAMA_PORT}
url: ${env.OLLAMA_URL:http://localhost:11434}
memory:
- provider_id: faiss
provider_type: inline::faiss

View file

@@ -1,14 +1,14 @@
version: '2'
built_at: 2024-11-17 15:19:07.405727
built_at: 2024-11-17 19:33:00
image_name: remote-vllm
docker_image: llamastack/distribution-remote-vllm:test-0.0.52rc3
conda_env: null
apis:
- telemetry
- agents
- memory
- inference
- agents
- safety
- inference
- telemetry
providers:
inference:
- provider_id: vllm-inference

View file

@@ -1,14 +1,14 @@
version: '2'
built_at: 2024-11-17 15:19:07.395327
built_at: 2024-11-17 19:33:00
image_name: remote-vllm
docker_image: llamastack/distribution-remote-vllm:test-0.0.52rc3
conda_env: null
apis:
- telemetry
- agents
- memory
- inference
- agents
- safety
- inference
- telemetry
providers:
inference:
- provider_id: vllm-inference

View file

@@ -1,14 +1,14 @@
version: '2'
built_at: 2024-11-17 15:19:09.184709
built_at: 2024-11-17 19:33:00
image_name: tgi
docker_image: llamastack/distribution-tgi:test-0.0.52rc3
conda_env: null
apis:
- telemetry
- agents
- memory
- inference
- agents
- safety
- inference
- telemetry
providers:
inference:
- provider_id: tgi-inference

View file

@@ -1,14 +1,14 @@
version: '2'
built_at: 2024-11-17 15:19:09.156305
built_at: 2024-11-17 19:33:00
image_name: tgi
docker_image: llamastack/distribution-tgi:test-0.0.52rc3
conda_env: null
apis:
- telemetry
- agents
- memory
- inference
- agents
- safety
- inference
- telemetry
providers:
inference:
- provider_id: tgi-inference

View file

@@ -20,7 +20,7 @@ The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `OLLAMA_PORT`: Port of the Ollama server (default: `14343`)
- `OLLAMA_URL`: URL of the Ollama server (default: `http://host.docker.internal:11434`)
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
### Models

View file

@@ -217,15 +217,23 @@ class StackBuild(Subcommand):
provider_types = [provider_types]
for i, provider_type in enumerate(provider_types):
p_spec = Provider(
provider_id=f"{provider_type}-{i}",
provider_type=provider_type,
config={},
)
pid = provider_type.split("::")[-1]
config_type = instantiate_class_type(
provider_registry[Api(api)][provider_type].config_class
)
p_spec.config = config_type()
if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(
__distro_dir__=f"distributions/{build_config.name}"
)
else:
config = {}
p_spec = Provider(
provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid,
provider_type=provider_type,
config=config,
)
run_config.providers[api].append(p_spec)
os.makedirs(build_dir, exist_ok=True)
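
`llama stack build` now derives provider IDs from the short provider name and asks each config class for a sample run config instead of emitting an empty `config: {}`. A self-contained sketch of the new flow, with stand-ins for the real `Provider` type and provider registry from llama_stack:

```python
# Sketch only; Provider and OllamaImplConfig below are stand-ins, not the real llama_stack types.
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Provider:  # stand-in for llama_stack.distribution.datatypes.Provider
    provider_id: str
    provider_type: str
    config: Dict[str, Any] = field(default_factory=dict)


class OllamaImplConfig:  # stand-in config class exposing the optional hook
    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {"url": "${env.OLLAMA_URL:http://localhost:11434}"}


def build_providers(
    provider_types: List[str], config_classes: Dict[str, type], distro_name: str
) -> List[Provider]:
    providers = []
    for i, provider_type in enumerate(provider_types):
        pid = provider_type.split("::")[-1]  # e.g. "remote::ollama" -> "ollama"
        config_type = config_classes[provider_type]
        if hasattr(config_type, "sample_run_config"):
            config = config_type.sample_run_config(__distro_dir__=f"distributions/{distro_name}")
        else:
            config = {}
        providers.append(
            Provider(
                # suffix with the index only when several providers serve the same API
                provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid,
                provider_type=provider_type,
                config=config,
            )
        )
    return providers


print(build_providers(["remote::ollama"], {"remote::ollama": OllamaImplConfig}, "ollama"))
# -> [Provider(provider_id='ollama', provider_type='remote::ollama',
#              config={'url': '${env.OLLAMA_URL:http://localhost:11434}'})]
```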

View file

@@ -39,6 +39,13 @@ class StackRun(Subcommand):
help="Disable IPv6 support",
default=False,
)
self.parser.add_argument(
"--env",
action="append",
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
default=[],
metavar="KEY=VALUE",
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
from pathlib import Path
@@ -97,4 +104,16 @@ class StackRun(Subcommand):
if args.disable_ipv6:
run_args.append("--disable-ipv6")
for env_var in args.env:
if "=" not in env_var:
self.parser.error(
f"Environment variable '{env_var}' must be in KEY=VALUE format"
)
return
key, value = env_var.split("=", 1) # split on first = only
if not key:
self.parser.error(f"Environment variable '{env_var}' has empty key")
return
run_args.extend(["--env", f"{key}={value}"])
run_with_pty(run_args)
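
`llama stack run` gains a repeatable `--env KEY=VALUE` flag that is validated and forwarded to the launcher scripts, e.g. `llama stack run ./run.yaml --env OLLAMA_URL=http://localhost:11434 --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct`. The same validation, isolated as a hypothetical helper for illustration:

```python
from typing import Dict, List


def parse_env_args(pairs: List[str]) -> Dict[str, str]:
    """Validate repeated --env values and return them as a dict (illustrative only)."""
    env: Dict[str, str] = {}
    for pair in pairs:
        if "=" not in pair:
            raise ValueError(f"Environment variable '{pair}' must be in KEY=VALUE format")
        key, value = pair.split("=", 1)  # split on the first '=' only; values may contain '='
        if not key:
            raise ValueError(f"Environment variable '{pair}' has empty key")
        env[key] = value
    return env


print(parse_env_args(["OLLAMA_URL=http://localhost:11434", "EXTRA=a=b"]))
# -> {'OLLAMA_URL': 'http://localhost:11434', 'EXTRA': 'a=b'}
```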

View file

@@ -146,6 +146,8 @@ fi
# Set version tag based on PyPI version
if [ -n "$TEST_PYPI_VERSION" ]; then
version_tag="test-$TEST_PYPI_VERSION"
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_MODELS_DIR" ]]; then
version_tag="dev"
else
URL="https://pypi.org/pypi/llama-stack/json"
version_tag=$(curl -s $URL | jq -r '.info.version')

View file

@@ -33,10 +33,33 @@ shift
port="$1"
shift
# Process environment variables from --env arguments
env_vars=""
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
# collect environment variables so we can set them after activating the conda env
env_vars="$env_vars $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
shift
;;
esac
done
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
$CONDA_PREFIX/bin/python \
set -x
$env_vars \
$CONDA_PREFIX/bin/python \
-m llama_stack.distribution.server.server \
--yaml_config "$yaml_config" \
--port "$port" "$@"
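
The conda launcher collects every `KEY=VALUE` pair into `$env_vars` and prefixes it to the `python -m llama_stack.distribution.server.server` command, so the variables apply only to that process. A rough Python equivalent of the same effect (not what the script actually runs; the config path and port below are placeholders):

```python
import os
import subprocess

# Hypothetical illustration: pass the collected KEY=VALUE pairs to the server process only,
# mirroring the `FOO=bar python -m ...` prefix used by the conda launcher script.
env_pairs = {"OLLAMA_URL": "http://localhost:11434"}

subprocess.run(
    ["python", "-m", "llama_stack.distribution.server.server",
     "--yaml_config", "run.yaml", "--port", "5001"],
    env={**os.environ, **env_pairs},  # inherit the shell env, overlay the --env values
    check=True,
)
```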

View file

@@ -31,7 +31,7 @@ if [ $# -lt 3 ]; then
fi
build_name="$1"
docker_image="distribution-$build_name"
docker_image="localhost/distribution-$build_name"
shift
yaml_config="$1"
@@ -40,6 +40,26 @@ shift
port="$1"
shift
# Process environment variables from --env arguments
env_vars=""
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
echo "env = $2"
if [[ -n "$2" ]]; then
env_vars="$env_vars -e $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
shift
;;
esac
done
set -x
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
@@ -59,15 +79,18 @@ fi
version_tag="latest"
if [ -n "$PYPI_VERSION" ]; then
version_tag="$PYPI_VERSION"
elif [ -n "$LLAMA_STACK_DIR" ]; then
version_tag="dev"
elif [ -n "$TEST_PYPI_VERSION" ]; then
version_tag="test-$TEST_PYPI_VERSION"
fi
$DOCKER_BINARY run $DOCKER_OPTS -it \
-p $port:$port \
$env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \
$docker_image:$version_tag \
python -m llama_stack.distribution.server.server \
--yaml_config /app/config.yaml \
--port $port "$@"
--port "$port"
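
The docker launcher instead turns each pair into a `-e KEY=VALUE` flag on `docker run` (and now also uses the `localhost/distribution-<name>` image name and a `dev` tag when building against a local checkout). A small illustrative sketch of the flag construction:

```python
from typing import Dict, List


def docker_env_flags(env_pairs: Dict[str, str]) -> List[str]:
    """Turn KEY=VALUE pairs into repeated -e flags for docker run (illustration only)."""
    flags: List[str] = []
    for key, value in env_pairs.items():
        flags.extend(["-e", f"{key}={value}"])
    return flags


print(docker_env_flags({"OLLAMA_URL": "http://host.docker.internal:11434"}))
# -> ['-e', 'OLLAMA_URL=http://host.docker.internal:11434']
```

The documented `OLLAMA_URL` default of `http://host.docker.internal:11434` fits this path: it lets the container reach an Ollama server running on the host.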

View file

@@ -6,17 +6,17 @@
from typing import Any, Dict
from llama_stack.distribution.datatypes import RemoteProviderConfig
from pydantic import BaseModel
DEFAULT_OLLAMA_PORT = 11434
DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(RemoteProviderConfig):
port: int
class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
@classmethod
def sample_run_config(
cls, port_str: str = str(DEFAULT_OLLAMA_PORT)
cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs
) -> Dict[str, Any]:
return {"port": port_str}
return {"url": url}
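
`OllamaImplConfig` drops the `port` field in favor of a single `url`, and `sample_run_config` now emits the env placeholder rather than a literal value. Assuming standard pydantic behavior, usage looks roughly like this:

```python
from pydantic import BaseModel

DEFAULT_OLLAMA_URL = "http://localhost:11434"


class OllamaImplConfig(BaseModel):  # mirrors the new config class from the diff
    url: str = DEFAULT_OLLAMA_URL

    @classmethod
    def sample_run_config(cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs):
        return {"url": url}


# Template generation writes the placeholder into run.yaml ...
print(OllamaImplConfig.sample_run_config())  # {'url': '${env.OLLAMA_URL:http://localhost:11434}'}

# ... and at run time the provider is constructed with the resolved value.
print(OllamaImplConfig(url="http://host.docker.internal:11434").url)
```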

View file

@@ -82,7 +82,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncClient(host=self.url)
async def initialize(self) -> None:
print("Initializing Ollama, checking connectivity to server...")
print(f"checking connectivity to Ollama at `{self.url}`...")
try:
await self.client.ps()
except httpx.ConnectError as e:

View file

@@ -21,7 +21,7 @@ class TGIImplConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, url: str = "${env.TGI_URL}"):
def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
return {
"url": url,
}

View file

@@ -29,6 +29,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
def sample_run_config(
cls,
url: str = "${env.VLLM_URL}",
**kwargs,
):
return {
"url": url,

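The TGI and vLLM configs only add `**kwargs` to `sample_run_config`, so every provider config can be called with the same keyword arguments (such as the `__distro_dir__` passed during `llama stack build`) and ignore the ones it does not need. Under that assumption:

```python
from typing import Any, Dict


class VLLMInferenceAdapterConfig:  # reduced to the method relevant here
    @classmethod
    def sample_run_config(cls, url: str = "${env.VLLM_URL}", **kwargs) -> Dict[str, Any]:
        return {"url": url}


# Extra keyword arguments such as __distro_dir__ are accepted and ignored uniformly.
print(VLLMInferenceAdapterConfig.sample_run_config(__distro_dir__="distributions/remote-vllm"))
# -> {'url': '${env.VLLM_URL}'}
```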
View file

@@ -2,7 +2,7 @@ version: '2'
name: ollama
distribution_spec:
description: Use (an external) Ollama server for running LLM inference
docker_image: llamastack/distribution-ollama:test-0.0.52rc3
docker_image: null
providers:
inference:
- remote::ollama

View file

@@ -23,9 +23,7 @@ def get_distribution_template() -> DistributionTemplate:
inference_provider = Provider(
provider_id="ollama",
provider_type="remote::ollama",
config=OllamaImplConfig.sample_run_config(
port_str="${env.OLLAMA_PORT}",
),
config=OllamaImplConfig.sample_run_config(),
)
inference_model = ModelInput(
@@ -41,7 +39,7 @@ def get_distribution_template() -> DistributionTemplate:
name="ollama",
distro_type="self_hosted",
description="Use (an external) Ollama server for running LLM inference",
docker_image="llamastack/distribution-ollama:test-0.0.52rc3",
docker_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
default_models=[inference_model, safety_model],
@@ -74,9 +72,9 @@ def get_distribution_template() -> DistributionTemplate:
"meta-llama/Llama-3.2-3B-Instruct",
"Inference model loaded into the TGI server",
),
"OLLAMA_PORT": (
"14343",
"Port of the Ollama server",
"OLLAMA_URL": (
"http://host.docker.internal:11434",
"URL of the Ollama server",
),
"SAFETY_MODEL": (
"meta-llama/Llama-Guard-3-1B",

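The Ollama template now advertises `OLLAMA_URL` instead of `OLLAMA_PORT` in its `run_config_env_vars`, which is the source of the bullet list changed in the documentation hunk earlier. A hypothetical rendering helper, only to show how the `(default, description)` tuples map onto those doc lines:

```python
from typing import Dict, Tuple

# Entries as they appear in the ollama template (name -> (default, description)).
run_config_env_vars: Dict[str, Tuple[str, str]] = {
    "OLLAMA_URL": ("http://host.docker.internal:11434", "URL of the Ollama server"),
    "SAFETY_MODEL": ("meta-llama/Llama-Guard-3-1B", "Name of the safety (Llama-Guard) model to use"),
}


def render_env_var_docs(env_vars: Dict[str, Tuple[str, str]]) -> str:
    """Hypothetical helper: render the bullet lines seen in the documentation diff above."""
    return "\n".join(
        f"- `{name}`: {description} (default: `{default}`)"
        for name, (default, description) in env_vars.items()
    )


print(render_env_var_docs(run_config_env_vars))
```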
View file

@@ -87,7 +87,7 @@ class RunConfigSettings(BaseModel):
return StackRunConfig(
image_name=name,
docker_image=docker_image,
built_at=datetime.now(),
built_at=datetime.now().strftime("%Y-%m-%d %H:%M"),
apis=list(apis),
providers=provider_configs,
metadata_store=SqliteKVStoreConfig.sample_run_config(
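
`built_at` is now written as `datetime.now().strftime("%Y-%m-%d %H:%M")` instead of a raw `datetime.now()`, which lines up with the regenerated run.yaml files above showing `built_at: 2024-11-17 19:33:00` rather than a microsecond timestamp. Assuming `built_at` remains a datetime field on the run config, pydantic would parse the trimmed string back into a datetime, which is where the trailing `:00` comes from; a small check of that assumption:

```python
from datetime import datetime

from pydantic import BaseModel


class RunConfigStub(BaseModel):  # stand-in; assumes built_at is a datetime field as on StackRunConfig
    built_at: datetime


trimmed = datetime(2024, 11, 17, 19, 33, 48).strftime("%Y-%m-%d %H:%M")
print(trimmed)                                   # 2024-11-17 19:33  (seconds and microseconds dropped)
print(RunConfigStub(built_at=trimmed).built_at)  # 2024-11-17 19:33:00, as serialized in the run.yaml files
```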