getting closer to a distro definition, distro install + configure works

Ashwin Bharambe 2024-08-01 22:59:11 -07:00
parent dac2b5a1ed
commit 041cafbee3
11 changed files with 471 additions and 130 deletions

View file

@@ -5,18 +5,24 @@
 # the root directory of this source tree.
 import argparse
-import os
+import importlib
+import inspect
+import shlex
 from pathlib import Path
+from typing import Annotated, get_args, get_origin, Literal, Union

-import pkg_resources
+import yaml
+from pydantic import BaseModel
+from termcolor import cprint

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.distribution.registry import all_registered_distributions
-from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
+from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter
+from llama_toolchain.distribution.registry import available_distributions
+from llama_toolchain.utils import DISTRIBS_BASE_DIR
+
+from .utils import run_command

-CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
+DISTRIBS = available_distributions()


 class DistributionConfigure(Subcommand):
@@ -34,59 +40,198 @@ class DistributionConfigure(Subcommand):
         self.parser.set_defaults(func=self._run_distribution_configure_cmd)

     def _add_arguments(self):
-        distribs = all_registered_distributions()
         self.parser.add_argument(
             "--name",
             type=str,
             help="Name of the distribution to configure",
             default="local-source",
-            choices=[d.name for d in distribs],
+            choices=[d.name for d in available_distributions()],
         )
-    def read_user_inputs(self):
-        checkpoint_dir = input(
-            "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
-        )
-        model_parallel_size = input(
-            "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
-        )
-        assert model_parallel_size.isdigit() and int(model_parallel_size) in {
-            1,
-            8,
-        }, "model parallel size must be 1 or 8"
-
-        return checkpoint_dir, model_parallel_size
-
-    def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
-        default_conf_path = pkg_resources.resource_filename(
-            "llama_toolchain", "data/default_distribution_config.yaml"
-        )
-        with open(default_conf_path, "r") as f:
-            yaml_content = f.read()
-
-        yaml_content = yaml_content.format(
-            checkpoint_dir=checkpoint_dir,
-            model_parallel_size=model_parallel_size,
-        )
-
-        with open(yaml_output_path, "w") as yaml_file:
-            yaml_file.write(yaml_content.strip())
-
-        print(f"YAML configuration has been written to {yaml_output_path}")
     def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
-        checkpoint_dir, model_parallel_size = self.read_user_inputs()
-        checkpoint_dir = os.path.expanduser(checkpoint_dir)
-
-        assert (
-            Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
-        ), f"{checkpoint_dir} does not exist or it not a directory"
-
-        os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
-        yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml"
-
-        self.write_output_yaml(
-            checkpoint_dir,
-            model_parallel_size,
-            yaml_output_path,
-        )
+        dist = None
+        for d in DISTRIBS:
+            if d.name == args.name:
+                dist = d
+                break
+
+        if dist is None:
+            self.parser.error(f"Could not find distribution {args.name}")
+            return
+
+        env_file = DISTRIBS_BASE_DIR / dist.name / "conda.env"
+        # read this file to get the conda env name
+        assert env_file.exists(), f"Could not find conda env file {env_file}"
+        with open(env_file, "r") as f:
+            conda_env = f.read().strip()
+
+        configure_llama_distribution(dist, conda_env)
+
+
+def configure_llama_distribution(dist: Distribution, conda_env: str):
+    python_exe = run_command(shlex.split("which python"))
+    # simple check
+    if conda_env not in python_exe:
+        raise ValueError(
+            f"Please re-run configure by activating the `{conda_env}` conda environment"
+        )
+
+    adapter_configs = {}
+    for api_surface, adapter in dist.adapters.items():
+        if isinstance(adapter, PassthroughApiAdapter):
+            adapter_configs[api_surface.value] = adapter.dict()
+        else:
+            cprint(
+                f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"]
+            )
+            config_type = instantiate_class_type(adapter.config_class)
+            # TODO: when we are re-configuring, we should read existing values
+            config = prompt_for_config(config_type)
+            adapter_configs[api_surface.value] = config.dict()
+
+    dist_config = {
+        "adapters": adapter_configs,
+        "conda_env": conda_env,
+    }
+
+    yaml_output_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
+    with open(yaml_output_path, "w") as fp:
+        fp.write(yaml.dump(dist_config, sort_keys=False))
+
+    print(f"YAML configuration has been written to {yaml_output_path}")
+
+
+def instantiate_class_type(fully_qualified_name):
+    module_name, class_name = fully_qualified_name.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, class_name)
+
+
+def get_literal_values(field):
+    """Extract literal values from a field if it's a Literal type."""
+    if get_origin(field.annotation) is Literal:
+        return get_args(field.annotation)
+    return None
+
+
+def is_optional(field_type):
+    """Check if a field type is Optional."""
+    return get_origin(field_type) is Union and type(None) in get_args(field_type)
+
+
+def get_non_none_type(field_type):
+    """Get the non-None type from an Optional type."""
+    return next(arg for arg in get_args(field_type) if arg is not type(None))
+
+
+def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
+    """
+    Recursively prompt the user for configuration values based on a Pydantic BaseModel.
+
+    Args:
+        config_type: A Pydantic BaseModel class representing the configuration structure.
+
+    Returns:
+        An instance of the config_type with user-provided values.
+    """
+    config_data = {}
+
+    for field_name, field in config_type.__fields__.items():
+        field_type = field.annotation
+        default_value = (
+            field.default if not isinstance(field.default, type(Ellipsis)) else None
+        )
+        is_required = field.required
+
+        # Skip fields with Literal type
+        if get_origin(field_type) is Literal:
+            continue
+
+        # Check if the field is a discriminated union
+        if get_origin(field_type) is Annotated:
+            inner_type = get_args(field_type)[0]
+            if get_origin(inner_type) is Union:
+                discriminator = field.field_info.discriminator
+                if discriminator:
+                    union_types = get_args(inner_type)
+                    # Find the discriminator field in each union type
+                    type_map = {}
+                    for t in union_types:
+                        disc_field = t.__fields__[discriminator]
+                        literal_values = get_literal_values(disc_field)
+                        if literal_values:
+                            for value in literal_values:
+                                type_map[value] = t
+
+                    while True:
+                        discriminator_value = input(
+                            f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): "
+                        )
+                        if discriminator_value in type_map:
+                            chosen_type = type_map[discriminator_value]
+                            print(f"\nConfiguring {chosen_type.__name__}:")
+                            sub_config = prompt_for_config(chosen_type)
+                            config_data[field_name] = sub_config
+                            # Set the discriminator field in the sub-config
+                            setattr(sub_config, discriminator, discriminator_value)
+                            break
+                        else:
+                            print(f"Invalid {discriminator}. Please try again.")
+                    continue
+
+        if inspect.isclass(field_type) and issubclass(field_type, BaseModel):
+            print(f"\nEntering sub-configuration for {field_name}:")
+            config_data[field_name] = prompt_for_config(field_type)
+        else:
+            prompt = f"Enter value for {field_name}"
+            if default_value is not None:
+                prompt += f" (default: {default_value})"
+            if is_optional(field_type):
+                prompt += " (optional)"
+            elif is_required:
+                prompt += " (required)"
+            prompt += ": "
+
+            while True:
+                user_input = input(prompt)
+                if user_input == "":
+                    if default_value is not None:
+                        config_data[field_name] = default_value
+                        break
+                    elif is_optional(field_type):
+                        config_data[field_name] = None
+                        break
+                    elif not is_required:
+                        config_data[field_name] = None
+                        break
+                    else:
+                        print("This field is required. Please provide a value.")
+                        continue
+                try:
+                    # Handle Optional types
+                    if is_optional(field_type):
+                        if user_input.lower() == "none":
+                            config_data[field_name] = None
+                            break
+                        field_type = get_non_none_type(field_type)
+
+                    # Convert the input to the correct type
+                    if inspect.isclass(field_type) and issubclass(
+                        field_type, BaseModel
+                    ):
+                        # For nested BaseModels, we assume a dictionary-like string input
+                        import ast
+
+                        config_data[field_name] = field_type(
+                            **ast.literal_eval(user_input)
+                        )
+                    else:
+                        config_data[field_name] = field_type(user_input)
+                    break
+                except ValueError:
+                    print(
+                        f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
+                    )
+
+    return config_type(**config_data)
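
Editor's note: the new prompt_for_config helper walks a Pydantic (v1, as pinned in the registry) model and interactively asks for each field. A minimal usage sketch, assuming the file above is importable as llama_toolchain.cli.distribution.configure (the module path is a guess from the relative imports) and using a made-up ExampleInferenceConfig model:

# Hypothetical usage sketch for prompt_for_config; the import path and the
# ExampleInferenceConfig model are assumptions, not part of the commit.
from typing import Optional

from pydantic import BaseModel

from llama_toolchain.cli.distribution.configure import prompt_for_config


class ExampleInferenceConfig(BaseModel):
    url: str = "http://localhost:11434"   # pressing Enter keeps this default
    max_tokens: Optional[int] = None      # optional: empty input stores None


# Prompts for each non-Literal field and returns a validated instance.
config = prompt_for_config(ExampleInferenceConfig)
print(config.dict())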

View file

@@ -7,6 +7,7 @@
 import argparse

 from llama_toolchain.cli.subcommand import Subcommand

 from .configure import DistributionConfigure
+from .create import DistributionCreate
 from .install import DistributionInstall

View file

@@ -7,20 +7,16 @@
 import argparse
 import os
 import shlex
-import subprocess
-from pathlib import Path

 import pkg_resources

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.distribution.registry import all_registered_distributions
-from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
+from llama_toolchain.distribution.datatypes import distribution_dependencies
+from llama_toolchain.distribution.registry import available_distributions
+from llama_toolchain.utils import DISTRIBS_BASE_DIR
+
+from .utils import run_command, run_with_pty

-DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
-DISTRIBS = all_registered_distributions()
+DISTRIBS = available_distributions()


 class DistributionInstall(Subcommand):
@@ -70,13 +66,19 @@ class DistributionInstall(Subcommand):
             return

         os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
-        run_shell_script(script, args.conda_env, " ".join(dist.pip_packages))
+        deps = distribution_dependencies(dist)
+        run_command([script, args.conda_env, " ".join(deps)])

         with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
             f.write(f"{args.conda_env}\n")

-
-def run_shell_script(script_path, *args):
-    command_string = f"{script_path} {' '.join(shlex.quote(str(arg)) for arg in args)}"
-    command_list = shlex.split(command_string)
-    print(f"Running command: {command_list}")
-    subprocess.run(command_list, check=True, text=True)
+        # we need to run configure _within_ the conda environment and need to run with
+        # a pty since configure is interactive
+        python_exe = run_command(
+            shlex.split(f"conda run -n {args.conda_env} which python")
+        ).strip()
+        run_with_pty(
+            shlex.split(
+                f"{python_exe} -m llama_toolchain.cli.llama distribution configure --name {dist.name}"
+            )
+        )

View file

@@ -9,7 +9,8 @@ import argparse

 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.cli.table import print_table
-from llama_toolchain.distribution.registry import all_registered_distributions
+from llama_toolchain.distribution.datatypes import distribution_dependencies
+from llama_toolchain.distribution.registry import available_distributions


 class DistributionList(Subcommand):
@@ -37,12 +38,13 @@ class DistributionList(Subcommand):
         ]

         rows = []
-        for dist in all_registered_distributions():
+        for dist in available_distributions():
+            deps = distribution_dependencies(dist)
             rows.append(
                 [
                     dist.name,
                     dist.description,
-                    ", ".join(dist.pip_packages),
+                    ", ".join(deps),
                 ]
             )
         print_table(

View file

@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import errno
+import os
+import pty
+import select
+import subprocess
+import sys
+import termios
+import tty
+
+
+def run_with_pty(command):
+    old_settings = termios.tcgetattr(sys.stdin)
+
+    # Create a new pseudo-terminal
+    master, slave = pty.openpty()
+
+    try:
+        # ensure the terminal does not echo input
+        tty.setraw(sys.stdin.fileno())
+        process = subprocess.Popen(
+            command,
+            stdin=slave,
+            stdout=slave,
+            stderr=slave,
+            universal_newlines=True,
+        )
+
+        # Close the slave file descriptor as it's now owned by the subprocess
+        os.close(slave)
+
+        def handle_io():
+            while True:
+                rlist, _, _ = select.select([sys.stdin, master], [], [])
+
+                if sys.stdin in rlist:
+                    data = os.read(sys.stdin.fileno(), 1024)
+                    if not data:  # EOF
+                        break
+                    os.write(master, data)
+
+                if master in rlist:
+                    data = os.read(master, 1024)
+                    if not data:
+                        break
+                    os.write(sys.stdout.fileno(), data)
+
+        handle_io()
+    except (EOFError, KeyboardInterrupt):
+        pass
+    except OSError as e:
+        if e.errno != errno.EIO:
+            raise
+    finally:
+        # Restore original terminal settings
+        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
+
+        process.wait()
+        os.close(master)
+
+    return process.returncode
+
+
+def run_command(command):
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    output, error = process.communicate()
+    if process.returncode != 0:
+        print(f"Error: {error.decode('utf-8')}")
+        sys.exit(1)
+    return output.decode("utf-8")
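
Editor's note: a short usage sketch for the two new helpers. The import path llama_toolchain.cli.distribution.utils is an assumption based on the `from .utils import ...` lines in configure.py and install.py:

# Hypothetical usage of the new subprocess helpers (import path is a guess).
import shlex

from llama_toolchain.cli.distribution.utils import run_command, run_with_pty

# run_command: capture stdout of a non-interactive command; exits on failure.
python_exe = run_command(shlex.split("which python")).strip()

# run_with_pty: run an interactive program (like the configure flow) while
# forwarding the user's keystrokes through a pseudo-terminal.
returncode = run_with_pty([python_exe, "-c", "input('press enter to continue: ')"])
print(returncode)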

View file

@@ -32,15 +32,7 @@ class PassthroughApiAdapterConfig(BaseModel):
 @json_schema_type
 class PythonImplAdapterConfig(BaseModel):
     type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
-    pip_packages: List[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
-    module: str = Field(..., description="The name of the module to import")
-    entrypoint: str = Field(
-        ...,
-        description="The name of the entrypoint function which creates the implementation for the API",
-    )
+    adapter_id: str
     kwargs: Dict[str, Any] = Field(
         default_factory=dict, description="kwargs to pass to the entrypoint"
     )
@@ -63,21 +55,43 @@ AdapterConfig = Annotated[
 ]


-class DistributionConfig(BaseModel):
-    inference: AdapterConfig
-    safety: AdapterConfig
-    # configs for each API that the stack provides, e.g.
-    # agentic_system: AdapterConfig
-    # post_training: AdapterConfig
-
-
-class DistributionConfigDefaults(BaseModel):
-    inference: Dict[str, Any] = Field(
-        default_factory=dict, description="Default kwargs for the inference adapter"
-    )
-    safety: Dict[str, Any] = Field(
-        default_factory=dict, description="Default kwargs for the safety adapter"
-    )
+@json_schema_type
+class ApiSurface(Enum):
+    inference = "inference"
+    safety = "safety"
+
+
+@json_schema_type
+class Adapter(BaseModel):
+    api_surface: ApiSurface
+    adapter_id: str
+
+
+@json_schema_type
+class SourceAdapter(Adapter):
+    pip_packages: List[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+    module: str = Field(
+        ...,
+        description="""
+Fully-qualified name of the module to import. The module is expected to have
+a `get_adapter_instance()` method which will be passed a validated config object
+of type `config_class`.""",
+    )
+    config_class: str = Field(
+        ...,
+        description="Fully-qualified classname of the config for this adapter",
+    )
+
+
+@json_schema_type
+class PassthroughApiAdapter(Adapter):
+    base_url: str = Field(..., description="The base URL for the llama stack provider")
+    headers: Dict[str, str] = Field(
+        default_factory=dict,
+        description="Headers (e.g., authorization) to send with the request",
+    )
@@ -85,9 +99,22 @@ class Distribution(BaseModel):
     name: str
     description: str

-    # you must install the packages to get the functionality needed.
-    # later, we may have a docker image be the main artifact of
-    # a distribution.
-    pip_packages: List[str]
-
-    config_defaults: DistributionConfigDefaults
+    adapters: Dict[ApiSurface, Adapter] = Field(
+        default_factory=dict,
+        description="The API surfaces provided by this distribution",
+    )
+    additional_pip_packages: List[str] = Field(
+        default_factory=list,
+        description="Additional pip packages beyond those required by the adapters",
+    )
+
+
+def distribution_dependencies(distribution: Distribution) -> List[str]:
+    # only consider SourceAdapters when calculating dependencies
+    return [
+        dep
+        for adapter in distribution.adapters.values()
+        if isinstance(adapter, SourceAdapter)
+        for dep in adapter.pip_packages
+    ] + distribution.additional_pip_packages
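
Editor's note: to make the new data model concrete, here is a small illustrative sketch (not part of the diff) that builds a Distribution from one SourceAdapter and one PassthroughApiAdapter and computes its pip dependencies. The distribution name, adapter_id "remote-safety", and the URL are made up:

# Illustrative only; mirrors the types defined above.
from llama_toolchain.distribution.datatypes import (
    ApiSurface,
    Distribution,
    PassthroughApiAdapter,
    SourceAdapter,
    distribution_dependencies,
)

dist = Distribution(
    name="example-remote-safety",  # hypothetical, not in the registry
    description="Inference served from source, safety proxied to a remote endpoint",
    adapters={
        ApiSurface.inference: SourceAdapter(
            api_surface=ApiSurface.inference,
            adapter_id="meta-reference",
            pip_packages=["torch", "zmq"],
            module="llama_toolchain.inference.inference",
            config_class="llama_toolchain.inference.inference.InlineImplConfig",
        ),
        ApiSurface.safety: PassthroughApiAdapter(
            api_surface=ApiSurface.safety,
            adapter_id="remote-safety",  # hypothetical id
            base_url="https://example.com/safety",
            headers={"Authorization": "Bearer <token>"},
        ),
    },
    additional_pip_packages=["fastapi"],
)

# Only SourceAdapter packages count, plus the additional ones:
# ['torch', 'zmq', 'fastapi']
print(distribution_dependencies(dist))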

View file

@@ -6,30 +6,60 @@

 from typing import List

-from .datatypes import Distribution, DistributionConfigDefaults
+from llama_toolchain.inference.adapters import available_inference_adapters
+
+from .datatypes import ApiSurface, Distribution
+
+# This is currently duplicated from `requirements.txt` with a few minor changes
+# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
+# are moved to the appropriate distribution.
+COMMON_DEPENDENCIES = [
+    "accelerate",
+    "black==24.4.2",
+    "blobfile",
+    "codeshield",
+    "fairscale",
+    "fastapi",
+    "fire",
+    "flake8",
+    "httpx",
+    "huggingface-hub",
+    "hydra-core",
+    "hydra-zen",
+    "json-strong-typing",
+    "llama-models",
+    "omegaconf",
+    "pandas",
+    "Pillow",
+    "pydantic==1.10.13",
+    "pydantic_core==2.18.2",
+    "python-openapi",
+    "requests",
+    "tiktoken",
+    "torch",
+    "transformers",
+    "uvicorn",
+]


-def all_registered_distributions() -> List[Distribution]:
+def available_distributions() -> List[Distribution]:
+    inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
     return [
         Distribution(
             name="local-source",
             description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
-            pip_packages=[],
-            config_defaults=DistributionConfigDefaults(
-                inference={
-                    "max_seq_len": 4096,
-                    "max_batch_size": 1,
-                },
-                safety={},
-            ),
+            additional_pip_packages=COMMON_DEPENDENCIES,
+            adapters={
+                ApiSurface.inference: inference_adapters_by_id["meta-reference"],
+            },
         ),
         Distribution(
             name="local-ollama",
             description="Like local-source, but use ollama for running LLM inference",
-            pip_packages=["ollama"],
-            config_defaults=DistributionConfigDefaults(
-                inference={},
-                safety={},
-            ),
+            additional_pip_packages=COMMON_DEPENDENCIES,
+            adapters={
+                ApiSurface.inference: inference_adapters_by_id["meta-ollama"],
+            },
         ),
     ]
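
Editor's note: this registry plus distribution_dependencies is what the updated `llama distribution list` command (list.py above) renders. A minimal sketch of the same enumeration outside the CLI:

# Sketch: enumerate distributions the same way the updated list command does.
from llama_toolchain.distribution.datatypes import distribution_dependencies
from llama_toolchain.distribution.registry import available_distributions

for dist in available_distributions():
    deps = distribution_dependencies(dist)
    print(f"{dist.name}: {dist.description}")
    print(f"  api surfaces: {[s.value for s in dist.adapters]}")
    print(f"  pip deps: {len(deps)} packages")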

View file

@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter
+
+
+def available_inference_adapters() -> List[Adapter]:
+    return [
+        SourceAdapter(
+            api_surface=ApiSurface.inference,
+            adapter_id="meta-reference",
+            pip_packages=[
+                "torch",
+                "zmq",
+            ],
+            module="llama_toolchain.inference.inference",
+            config_class="llama_toolchain.inference.inference.InlineImplConfig",
+        ),
+        SourceAdapter(
+            api_surface=ApiSurface.inference,
+            adapter_id="meta-ollama",
+            pip_packages=[
+                "ollama",
+            ],
+            module="llama_toolchain.inference.ollama",
+            config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
+        ),
+    ]
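
Editor's note: these SourceAdapter entries are meant to be resolved at runtime: `config_class` is loaded and validated, then `module` is imported for its entry point. The datatypes docstring calls the entry point `get_adapter_instance()`, while the inference.py and ollama.py diffs below add `get_adapter_impl()`; the sketch assumes the latter. `resolve_adapter` is a hypothetical helper, not part of the commit:

# Sketch of resolving a SourceAdapter into a live implementation.
import importlib
from typing import Any, Dict

from llama_toolchain.distribution.datatypes import SourceAdapter


def instantiate_class_type(fully_qualified_name: str):
    # Same idea as the helper added in configure.py above.
    module_name, class_name = fully_qualified_name.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)


def resolve_adapter(adapter: SourceAdapter, config_dict: Dict[str, Any]):
    # Validate the raw config (e.g. loaded from config.yaml) against config_class.
    config_cls = instantiate_class_type(adapter.config_class)
    config = config_cls(**config_dict)

    # Import the adapter module and hand it the validated config.
    module = importlib.import_module(adapter.module)
    return module.get_adapter_impl(config)  # entry point added in this commit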

View file

@@ -16,8 +16,8 @@ from .api.datatypes import (
     ToolCallParseStatus,
 )
 from .api.endpoints import (
-    ChatCompletionResponse,
     ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
@@ -25,6 +25,13 @@ from .api.endpoints import (
 from .model_parallel import LlamaModelParallelGenerator


+def get_adapter_impl(config: InlineImplConfig) -> Inference:
+    assert isinstance(
+        config, InlineImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    return InferenceImpl(config)
+
+
 class InferenceImpl(Inference):
     def __init__(self, config: InlineImplConfig) -> None:

View file

@@ -1,9 +1,14 @@
-import httpx
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 import uuid
 from typing import AsyncGenerator

-from ollama import AsyncClient
+import httpx

 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
@@ -14,6 +19,8 @@ from llama_models.llama3_1.api.datatypes import (
 )
 from llama_models.llama3_1.api.tool_utils import ToolUtils

+from ollama import AsyncClient
+
 from .api.config import OllamaImplConfig
 from .api.datatypes import (
     ChatCompletionResponseEvent,
@@ -22,14 +29,20 @@ from .api.datatypes import (
     ToolCallParseStatus,
 )
 from .api.endpoints import (
-    ChatCompletionResponse,
     ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
 )


+def get_adapter_impl(config: OllamaImplConfig) -> Inference:
+    assert isinstance(
+        config, OllamaImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    return OllamaInference(config)
+
+
 class OllamaInference(Inference):
@@ -41,9 +54,13 @@ class OllamaInference(Inference):
         self.client = AsyncClient(host=self.config.url)
         try:
             status = await self.client.pull(self.model)
-            assert status['status'] == 'success', f"Failed to pull model {self.model} in ollama"
+            assert (
+                status["status"] == "success"
+            ), f"Failed to pull model {self.model} in ollama"
         except httpx.ConnectError:
-            print("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+            print(
+                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
+            )
             raise

     async def shutdown(self) -> None:
@@ -55,9 +72,7 @@ class OllamaInference(Inference):
     def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
         ollama_messages = []
         for message in messages:
-            ollama_messages.append(
-                {"role": message.role, "content": message.content}
-            )
+            ollama_messages.append({"role": message.role, "content": message.content})

         return ollama_messages
@@ -67,16 +82,16 @@
             model=self.model,
             messages=self._messages_to_ollama_messages(request.messages),
             stream=False,
-            #TODO: add support for options like temp, top_p, max_seq_length, etc
+            # TODO: add support for options like temp, top_p, max_seq_length, etc
         )
-        if r['done']:
-            if r['done_reason'] == 'stop':
+        if r["done"]:
+            if r["done_reason"] == "stop":
                 stop_reason = StopReason.end_of_turn
-            elif r['done_reason'] == 'length':
+            elif r["done_reason"] == "length":
                 stop_reason = StopReason.out_of_tokens

         completion_message = decode_assistant_message_from_content(
-            r['message']['content'],
+            r["message"]["content"],
             stop_reason,
         )
         yield ChatCompletionResponse(
@@ -94,7 +109,7 @@ class OllamaInference(Inference):
         stream = await self.client.chat(
             model=self.model,
             messages=self._messages_to_ollama_messages(request.messages),
-            stream=True
+            stream=True,
         )

         buffer = ""
@@ -103,14 +118,14 @@
         async for chunk in stream:
             # check if ollama is done
-            if chunk['done']:
-                if chunk['done_reason'] == 'stop':
+            if chunk["done"]:
+                if chunk["done_reason"] == "stop":
                     stop_reason = StopReason.end_of_turn
-                elif chunk['done_reason'] == 'length':
+                elif chunk["done_reason"] == "length":
                     stop_reason = StopReason.out_of_tokens
                 break

-            text = chunk['message']['content']
+            text = chunk["message"]["content"]

             # check if its a tool call ( aka starts with <|python_tag|> )
             if not ipython and text.startswith("<|python_tag|>"):
@@ -197,7 +212,7 @@ class OllamaInference(Inference):
         )


-#TODO: Consolidate this with impl in llama-models
+# TODO: Consolidate this with impl in llama-models
 def decode_assistant_message_from_content(
     content: str,
     stop_reason: StopReason,

View file

@@ -6,6 +6,7 @@
 import getpass
 import os
+from pathlib import Path
 from typing import Optional

 from hydra import compose, initialize, MissingConfigException
@@ -16,6 +17,8 @@ from omegaconf import OmegaConf

 LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")

+DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
+

 def get_root_directory():
     current_dir = os.path.dirname(os.path.abspath(__file__))