mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
use logging instead of prints (#499)
# What does this PR do? This PR converts all print statements to use logging. Things changed: - Had to add `await start_trace("sse_generator")` to server.py to actually get tracing working; otherwise no logs were showing up. - If no telemetry provider is configured in the run.yaml, logs are written to stdout. - By default, logs are emitted as JSON, but an option is exposed to output them in a human-readable format instead.
This commit is contained in:
parent
4e1105e563
commit
6395dadc2b
36 changed files with 234 additions and 163 deletions
|
@ -14,15 +14,19 @@ import httpx
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||||
|
|
||||||
from .agents import * # noqa: F403
|
from .agents import * # noqa: F403
|
||||||
|
import logging
|
||||||
|
|
||||||
from .event_logger import EventLogger
|
from .event_logger import EventLogger
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,13 +97,12 @@ class AgentsClient(Agents):
|
||||||
try:
|
try:
|
||||||
jdata = json.loads(data)
|
jdata = json.loads(data)
|
||||||
if "error" in jdata:
|
if "error" in jdata:
|
||||||
cprint(data, "red")
|
log.error(data)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield AgentTurnResponseStreamChunk(**jdata)
|
yield AgentTurnResponseStreamChunk(**jdata)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(data)
|
log.error(f"Error with parsing or validation: {e}")
|
||||||
print(f"Error with parsing or validation: {e}")
|
|
||||||
|
|
||||||
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
|
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
|
||||||
raise NotImplementedError("Non-streaming not implemented yet")
|
raise NotImplementedError("Non-streaming not implemented yet")
|
||||||
|
@ -125,7 +128,7 @@ async def _run_agent(
|
||||||
)
|
)
|
||||||
|
|
||||||
for content in user_prompts:
|
for content in user_prompts:
|
||||||
cprint(f"User> {content}", color="white", attrs=["bold"])
|
log.info(f"User> {content}", color="white", attrs=["bold"])
|
||||||
iterator = await api.create_agent_turn(
|
iterator = await api.create_agent_turn(
|
||||||
AgentTurnCreateRequest(
|
AgentTurnCreateRequest(
|
||||||
agent_id=create_response.agent_id,
|
agent_id=create_response.agent_id,
|
||||||
|
@ -138,9 +141,9 @@ async def _run_agent(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async for event, log in EventLogger().log(iterator):
|
async for event, logger in EventLogger().log(iterator):
|
||||||
if log is not None:
|
if logger is not None:
|
||||||
log.print()
|
log.info(logger)
|
||||||
|
|
||||||
|
|
||||||
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
|
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
|
||||||
|
|
|
@ -4,14 +4,13 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.distribution.utils.exec import run_with_pty
|
from llama_stack.distribution.utils.exec import run_with_pty
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
@ -22,6 +21,8 @@ from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
# These are the dependencies needed by the distribution server.
|
# These are the dependencies needed by the distribution server.
|
||||||
# `llama-stack` is automatically installed by the installation script.
|
# `llama-stack` is automatically installed by the installation script.
|
||||||
SERVER_DEPENDENCIES = [
|
SERVER_DEPENDENCIES = [
|
||||||
|
@ -89,12 +90,12 @@ def get_provider_dependencies(
|
||||||
def print_pip_install_help(providers: Dict[str, List[Provider]]):
|
def print_pip_install_help(providers: Dict[str, List[Provider]]):
|
||||||
normal_deps, special_deps = get_provider_dependencies(providers)
|
normal_deps, special_deps = get_provider_dependencies(providers)
|
||||||
|
|
||||||
print(
|
log.info(
|
||||||
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
|
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
|
||||||
)
|
)
|
||||||
for special_dep in special_deps:
|
for special_dep in special_deps:
|
||||||
print(f"\tpip install {special_dep}")
|
log.info(f"\tpip install {special_dep}")
|
||||||
print()
|
log.info()
|
||||||
|
|
||||||
|
|
||||||
def build_image(build_config: BuildConfig, build_file_path: Path):
|
def build_image(build_config: BuildConfig, build_file_path: Path):
|
||||||
|
@ -133,9 +134,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
||||||
|
|
||||||
return_code = run_with_pty(args)
|
return_code = run_with_pty(args)
|
||||||
if return_code != 0:
|
if return_code != 0:
|
||||||
cprint(
|
log.error(
|
||||||
f"Failed to build target {build_config.name} with return code {return_code}",
|
f"Failed to build target {build_config.name} with return code {return_code}",
|
||||||
color="red",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return return_code
|
return return_code
|
||||||
|
|
|
@ -3,12 +3,12 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
import logging
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
builtin_automatically_routed_apis,
|
builtin_automatically_routed_apis,
|
||||||
|
@ -22,6 +22,8 @@ from llama_stack.apis.models import * # noqa: F403
|
||||||
from llama_stack.apis.shields import * # noqa: F403
|
from llama_stack.apis.shields import * # noqa: F403
|
||||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def configure_single_provider(
|
def configure_single_provider(
|
||||||
registry: Dict[str, ProviderSpec], provider: Provider
|
registry: Dict[str, ProviderSpec], provider: Provider
|
||||||
|
@ -50,7 +52,7 @@ def configure_api_providers(
|
||||||
is_nux = len(config.providers) == 0
|
is_nux = len(config.providers) == 0
|
||||||
|
|
||||||
if is_nux:
|
if is_nux:
|
||||||
print(
|
logger.info(
|
||||||
textwrap.dedent(
|
textwrap.dedent(
|
||||||
"""
|
"""
|
||||||
Llama Stack is composed of several APIs working together. For each API served by the Stack,
|
Llama Stack is composed of several APIs working together. For each API served by the Stack,
|
||||||
|
@ -76,18 +78,18 @@ def configure_api_providers(
|
||||||
|
|
||||||
existing_providers = config.providers.get(api_str, [])
|
existing_providers = config.providers.get(api_str, [])
|
||||||
if existing_providers:
|
if existing_providers:
|
||||||
cprint(
|
logger.info(
|
||||||
f"Re-configuring existing providers for API `{api_str}`...",
|
f"Re-configuring existing providers for API `{api_str}`...",
|
||||||
"green",
|
"green",
|
||||||
attrs=["bold"],
|
attrs=["bold"],
|
||||||
)
|
)
|
||||||
updated_providers = []
|
updated_providers = []
|
||||||
for p in existing_providers:
|
for p in existing_providers:
|
||||||
print(f"> Configuring provider `({p.provider_type})`")
|
logger.info(f"> Configuring provider `({p.provider_type})`")
|
||||||
updated_providers.append(
|
updated_providers.append(
|
||||||
configure_single_provider(provider_registry[api], p)
|
configure_single_provider(provider_registry[api], p)
|
||||||
)
|
)
|
||||||
print("")
|
logger.info("")
|
||||||
else:
|
else:
|
||||||
# we are newly configuring this API
|
# we are newly configuring this API
|
||||||
plist = build_spec.providers.get(api_str, [])
|
plist = build_spec.providers.get(api_str, [])
|
||||||
|
@ -96,17 +98,17 @@ def configure_api_providers(
|
||||||
if not plist:
|
if not plist:
|
||||||
raise ValueError(f"No provider configured for API {api_str}?")
|
raise ValueError(f"No provider configured for API {api_str}?")
|
||||||
|
|
||||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||||
updated_providers = []
|
updated_providers = []
|
||||||
for i, provider_type in enumerate(plist):
|
for i, provider_type in enumerate(plist):
|
||||||
if i >= 1:
|
if i >= 1:
|
||||||
others = ", ".join(plist[i:])
|
others = ", ".join(plist[i:])
|
||||||
print(
|
logger.info(
|
||||||
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
|
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
print(f"> Configuring provider `({provider_type})`")
|
logger.info(f"> Configuring provider `({provider_type})`")
|
||||||
updated_providers.append(
|
updated_providers.append(
|
||||||
configure_single_provider(
|
configure_single_provider(
|
||||||
provider_registry[api],
|
provider_registry[api],
|
||||||
|
@ -121,7 +123,7 @@ def configure_api_providers(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
print("")
|
logger.info("")
|
||||||
|
|
||||||
config.providers[api_str] = updated_providers
|
config.providers[api_str] = updated_providers
|
||||||
|
|
||||||
|
@ -182,7 +184,7 @@ def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfi
|
||||||
return StackRunConfig(**config_dict)
|
return StackRunConfig(**config_dict)
|
||||||
|
|
||||||
if "routing_table" in config_dict:
|
if "routing_table" in config_dict:
|
||||||
print("Upgrading config...")
|
logger.info("Upgrading config...")
|
||||||
config_dict = upgrade_from_routing_table(config_dict)
|
config_dict = upgrade_from_routing_table(config_dict)
|
||||||
|
|
||||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||||
|
|
|
@ -5,11 +5,14 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from .utils.dynamic import instantiate_class_type
|
from .utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
_THREAD_LOCAL = threading.local()
|
_THREAD_LOCAL = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,7 +35,7 @@ class NeedsRequestProviderData:
|
||||||
provider_data = validator(**val)
|
provider_data = validator(**val)
|
||||||
return provider_data
|
return provider_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error parsing provider data", e)
|
log.error("Error parsing provider data", e)
|
||||||
|
|
||||||
|
|
||||||
def set_request_provider_data(headers: Dict[str, str]):
|
def set_request_provider_data(headers: Dict[str, str]):
|
||||||
|
@ -51,7 +54,7 @@ def set_request_provider_data(headers: Dict[str, str]):
|
||||||
try:
|
try:
|
||||||
val = json.loads(val)
|
val = json.loads(val)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print("Provider data not encoded as a JSON object!", val)
|
log.error("Provider data not encoded as a JSON object!", val)
|
||||||
return
|
return
|
||||||
|
|
||||||
_THREAD_LOCAL.provider_data_header_value = val
|
_THREAD_LOCAL.provider_data_header_value = val
|
||||||
|
|
|
@ -8,11 +8,12 @@ import inspect
|
||||||
|
|
||||||
from typing import Any, Dict, List, Set
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import * # noqa: F403
|
from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
@ -33,6 +34,8 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class InvalidProviderError(Exception):
|
class InvalidProviderError(Exception):
|
||||||
pass
|
pass
|
||||||
|
@ -115,11 +118,11 @@ async def resolve_impls(
|
||||||
|
|
||||||
p = provider_registry[api][provider.provider_type]
|
p = provider_registry[api][provider.provider_type]
|
||||||
if p.deprecation_error:
|
if p.deprecation_error:
|
||||||
cprint(p.deprecation_error, "red", attrs=["bold"])
|
log.error(p.deprecation_error, "red", attrs=["bold"])
|
||||||
raise InvalidProviderError(p.deprecation_error)
|
raise InvalidProviderError(p.deprecation_error)
|
||||||
|
|
||||||
elif p.deprecation_warning:
|
elif p.deprecation_warning:
|
||||||
cprint(
|
log.warning(
|
||||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||||
"yellow",
|
"yellow",
|
||||||
attrs=["bold"],
|
attrs=["bold"],
|
||||||
|
@ -199,10 +202,10 @@ async def resolve_impls(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Resolved {len(sorted_providers)} providers")
|
log.info(f"Resolved {len(sorted_providers)} providers")
|
||||||
for api_str, provider in sorted_providers:
|
for api_str, provider in sorted_providers:
|
||||||
print(f" {api_str} => {provider.provider_id}")
|
log.info(f" {api_str} => {provider.provider_id}")
|
||||||
print("")
|
log.info("")
|
||||||
|
|
||||||
impls = {}
|
impls = {}
|
||||||
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
||||||
|
@ -339,7 +342,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||||
obj_params = set(obj_sig.parameters)
|
obj_params = set(obj_sig.parameters)
|
||||||
obj_params.discard("self")
|
obj_params.discard("self")
|
||||||
if not (proto_params <= obj_params):
|
if not (proto_params <= obj_params):
|
||||||
print(
|
log.error(
|
||||||
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
|
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
|
||||||
)
|
)
|
||||||
missing_methods.append((name, "signature_mismatch"))
|
missing_methods.append((name, "signature_mismatch"))
|
||||||
|
|
|
@ -46,6 +46,10 @@ from llama_stack.distribution.stack import (
|
||||||
replace_env_vars,
|
replace_env_vars,
|
||||||
validate_env_pair,
|
validate_env_pair,
|
||||||
)
|
)
|
||||||
|
from llama_stack.providers.inline.meta_reference.telemetry.console import (
|
||||||
|
ConsoleConfig,
|
||||||
|
ConsoleTelemetryImpl,
|
||||||
|
)
|
||||||
|
|
||||||
from .endpoints import get_all_api_endpoints
|
from .endpoints import get_all_api_endpoints
|
||||||
|
|
||||||
|
@ -196,7 +200,6 @@ def handle_sigint(app, *args, **kwargs):
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
print("Starting up")
|
print("Starting up")
|
||||||
yield
|
yield
|
||||||
|
|
||||||
print("Shutting down")
|
print("Shutting down")
|
||||||
for impl in app.__llama_stack_impls__.values():
|
for impl in app.__llama_stack_impls__.values():
|
||||||
await impl.shutdown()
|
await impl.shutdown()
|
||||||
|
@ -214,6 +217,7 @@ async def maybe_await(value):
|
||||||
|
|
||||||
|
|
||||||
async def sse_generator(event_gen):
|
async def sse_generator(event_gen):
|
||||||
|
await start_trace("sse_generator")
|
||||||
try:
|
try:
|
||||||
event_gen = await event_gen
|
event_gen = await event_gen
|
||||||
async for item in event_gen:
|
async for item in event_gen:
|
||||||
|
@ -333,7 +337,7 @@ def main():
|
||||||
print("Run configuration:")
|
print("Run configuration:")
|
||||||
print(yaml.dump(config.model_dump(), indent=2))
|
print(yaml.dump(config.model_dump(), indent=2))
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
impls = asyncio.run(construct_stack(config))
|
impls = asyncio.run(construct_stack(config))
|
||||||
|
@ -342,6 +346,8 @@ def main():
|
||||||
|
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
else:
|
||||||
|
setup_logger(ConsoleTelemetryImpl(ConsoleConfig()))
|
||||||
|
|
||||||
all_endpoints = get_all_api_endpoints()
|
all_endpoints = get_all_api_endpoints()
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
@ -40,6 +41,8 @@ from llama_stack.distribution.store.registry import create_dist_registry
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
LLAMA_STACK_API_VERSION = "alpha"
|
LLAMA_STACK_API_VERSION = "alpha"
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,11 +96,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
||||||
|
|
||||||
method = getattr(impls[api], list_method)
|
method = getattr(impls[api], list_method)
|
||||||
for obj in await method():
|
for obj in await method():
|
||||||
print(
|
log.info(
|
||||||
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("")
|
log.info("")
|
||||||
|
|
||||||
|
|
||||||
class EnvVarError(Exception):
|
class EnvVarError(Exception):
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import errno
|
import errno
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import pty
|
import pty
|
||||||
import select
|
import select
|
||||||
|
@ -13,7 +14,7 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
import termios
|
import termios
|
||||||
|
|
||||||
from termcolor import cprint
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# run a command in a pseudo-terminal, with interrupt handling,
|
# run a command in a pseudo-terminal, with interrupt handling,
|
||||||
|
@ -29,7 +30,7 @@ def run_with_pty(command):
|
||||||
def sigint_handler(signum, frame):
|
def sigint_handler(signum, frame):
|
||||||
nonlocal ctrl_c_pressed
|
nonlocal ctrl_c_pressed
|
||||||
ctrl_c_pressed = True
|
ctrl_c_pressed = True
|
||||||
cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
|
log.info("\nCtrl-C detected. Aborting...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set up the signal handler
|
# Set up the signal handler
|
||||||
|
@ -100,6 +101,6 @@ def run_command(command):
|
||||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||||
output, error = process.communicate()
|
output, error = process.communicate()
|
||||||
if process.returncode != 0:
|
if process.returncode != 0:
|
||||||
print(f"Error: {error.decode('utf-8')}")
|
log.error(f"Error: {error.decode('utf-8')}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
return output.decode("utf-8")
|
return output.decode("utf-8")
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
||||||
|
@ -16,6 +17,8 @@ from pydantic_core import PydanticUndefinedType
|
||||||
|
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def is_list_of_primitives(field_type):
|
def is_list_of_primitives(field_type):
|
||||||
"""Check if a field type is a List of primitive types."""
|
"""Check if a field type is a List of primitive types."""
|
||||||
|
@ -111,7 +114,7 @@ def prompt_for_discriminated_union(
|
||||||
|
|
||||||
if discriminator_value in type_map:
|
if discriminator_value in type_map:
|
||||||
chosen_type = type_map[discriminator_value]
|
chosen_type = type_map[discriminator_value]
|
||||||
print(f"\nConfiguring {chosen_type.__name__}:")
|
log.info(f"\nConfiguring {chosen_type.__name__}:")
|
||||||
|
|
||||||
if existing_value and (
|
if existing_value and (
|
||||||
getattr(existing_value, discriminator) != discriminator_value
|
getattr(existing_value, discriminator) != discriminator_value
|
||||||
|
@ -123,7 +126,7 @@ def prompt_for_discriminated_union(
|
||||||
setattr(sub_config, discriminator, discriminator_value)
|
setattr(sub_config, discriminator, discriminator_value)
|
||||||
return sub_config
|
return sub_config
|
||||||
else:
|
else:
|
||||||
print(f"Invalid {discriminator}. Please try again.")
|
log.error(f"Invalid {discriminator}. Please try again.")
|
||||||
|
|
||||||
|
|
||||||
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
|
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
|
||||||
|
@ -180,7 +183,7 @@ def prompt_for_config(
|
||||||
config_data[field_name] = validated_value
|
config_data[field_name] = validated_value
|
||||||
break
|
break
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print(
|
log.error(
|
||||||
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
|
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
@ -197,7 +200,7 @@ def prompt_for_config(
|
||||||
config_data[field_name] = None
|
config_data[field_name] = None
|
||||||
continue
|
continue
|
||||||
nested_type = get_non_none_type(field_type)
|
nested_type = get_non_none_type(field_type)
|
||||||
print(f"Entering sub-configuration for {field_name}:")
|
log.info(f"Entering sub-configuration for {field_name}:")
|
||||||
config_data[field_name] = prompt_for_config(nested_type, existing_value)
|
config_data[field_name] = prompt_for_config(nested_type, existing_value)
|
||||||
elif is_optional(field_type) and is_discriminated_union(
|
elif is_optional(field_type) and is_discriminated_union(
|
||||||
get_non_none_type(field_type)
|
get_non_none_type(field_type)
|
||||||
|
@ -213,7 +216,7 @@ def prompt_for_config(
|
||||||
existing_value,
|
existing_value,
|
||||||
)
|
)
|
||||||
elif can_recurse(field_type):
|
elif can_recurse(field_type):
|
||||||
print(f"\nEntering sub-configuration for {field_name}:")
|
log.info(f"\nEntering sub-configuration for {field_name}:")
|
||||||
config_data[field_name] = prompt_for_config(
|
config_data[field_name] = prompt_for_config(
|
||||||
field_type,
|
field_type,
|
||||||
existing_value,
|
existing_value,
|
||||||
|
@ -240,7 +243,7 @@ def prompt_for_config(
|
||||||
config_data[field_name] = None
|
config_data[field_name] = None
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
print("This field is required. Please provide a value.")
|
log.error("This field is required. Please provide a value.")
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
@ -264,12 +267,12 @@ def prompt_for_config(
|
||||||
value = [element_type(item) for item in value]
|
value = [element_type(item) for item in value]
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print(
|
log.error(
|
||||||
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
|
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(f"{str(e)}")
|
log.error(f"{str(e)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
elif get_origin(field_type) is dict:
|
elif get_origin(field_type) is dict:
|
||||||
|
@ -281,7 +284,7 @@ def prompt_for_config(
|
||||||
)
|
)
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print(
|
log.error(
|
||||||
"Invalid JSON. Please enter a valid JSON-encoded dict."
|
"Invalid JSON. Please enter a valid JSON-encoded dict."
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
@ -298,7 +301,7 @@ def prompt_for_config(
|
||||||
value = field_type(user_input)
|
value = field_type(user_input)
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print(
|
log.error(
|
||||||
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
|
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
@ -311,6 +314,6 @@ def prompt_for_config(
|
||||||
config_data[field_name] = validated_value
|
config_data[field_name] = validated_value
|
||||||
break
|
break
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(f"Validation error: {str(e)}")
|
log.error(f"Validation error: {str(e)}")
|
||||||
|
|
||||||
return config_type(**config_data)
|
return config_type(**config_data)
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
|
@ -19,7 +20,6 @@ from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
@ -43,6 +43,8 @@ from .tools.builtin import (
|
||||||
)
|
)
|
||||||
from .tools.safety import SafeTool
|
from .tools.safety import SafeTool
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def make_random_string(length: int = 8):
|
def make_random_string(length: int = 8):
|
||||||
return "".join(
|
return "".join(
|
||||||
|
@ -137,7 +139,6 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
stop_reason=StopReason.end_of_turn,
|
stop_reason=StopReason.end_of_turn,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# print_dialog(messages)
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
|
@ -185,10 +186,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
):
|
):
|
||||||
if isinstance(chunk, CompletionMessage):
|
if isinstance(chunk, CompletionMessage):
|
||||||
cprint(
|
log.info(
|
||||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||||
"white",
|
|
||||||
attrs=["bold"],
|
|
||||||
)
|
)
|
||||||
output_message = chunk
|
output_message = chunk
|
||||||
continue
|
continue
|
||||||
|
@ -407,7 +406,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
|
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
|
||||||
else:
|
else:
|
||||||
msg_str = str(msg)
|
msg_str = str(msg)
|
||||||
cprint(f"{msg_str}", color=color)
|
log.info(f"{msg_str}")
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
|
@ -506,12 +505,12 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
if n_iter >= self.agent_config.max_infer_iters:
|
if n_iter >= self.agent_config.max_infer_iters:
|
||||||
cprint("Done with MAX iterations, exiting.")
|
log.info("Done with MAX iterations, exiting.")
|
||||||
yield message
|
yield message
|
||||||
break
|
break
|
||||||
|
|
||||||
if stop_reason == StopReason.out_of_tokens:
|
if stop_reason == StopReason.out_of_tokens:
|
||||||
cprint("Out of token budget, exiting.")
|
log.info("Out of token budget, exiting.")
|
||||||
yield message
|
yield message
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -525,10 +524,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
message.content = [message.content] + attachments
|
message.content = [message.content] + attachments
|
||||||
yield message
|
yield message
|
||||||
else:
|
else:
|
||||||
cprint(f"Partial message: {str(message)}", color="green")
|
# Logger.info() only accepts exc_info/stack_info/stacklevel/extra as keyword
# arguments; any other keyword (such as color=) raises TypeError.
log.info(f"Partial message: {str(message)}")
|
||||||
input_messages = input_messages + [message]
|
input_messages = input_messages + [message]
|
||||||
else:
|
else:
|
||||||
cprint(f"{str(message)}", color="green")
|
# Logger.info() only accepts exc_info/stack_info/stacklevel/extra as keyword
# arguments; any other keyword (such as color=) raises TypeError.
log.info(f"{str(message)}")
|
||||||
try:
|
try:
|
||||||
tool_call = message.tool_calls[0]
|
tool_call = message.tool_calls[0]
|
||||||
|
|
||||||
|
@ -740,9 +739,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
for c in chunks[: memory.max_chunks]:
|
for c in chunks[: memory.max_chunks]:
|
||||||
tokens += c.token_count
|
tokens += c.token_count
|
||||||
if tokens > memory.max_tokens_in_context:
|
if tokens > memory.max_tokens_in_context:
|
||||||
cprint(
|
log.error(
|
||||||
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
||||||
"red",
|
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
picked.append(f"id:{c.document_id}; content:{c.content}")
|
picked.append(f"id:{c.document_id}; content:{c.content}")
|
||||||
|
@ -786,7 +784,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
||||||
path = urlparse(uri).path
|
path = urlparse(uri).path
|
||||||
basename = os.path.basename(path)
|
basename = os.path.basename(path)
|
||||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||||
print(f"Downloading {url} -> {filepath}")
|
log.info(f"Downloading {url} -> {filepath}")
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.get(uri)
|
r = await client.get(uri)
|
||||||
|
@ -826,20 +824,3 @@ async def execute_tool_call_maybe(
|
||||||
tool = tools_dict[name]
|
tool = tools_dict[name]
|
||||||
result_messages = await tool.run(messages)
|
result_messages = await tool.run(messages)
|
||||||
return result_messages
|
return result_messages
|
||||||
|
|
||||||
|
|
||||||
def print_dialog(messages: List[Message]):
|
|
||||||
for i, m in enumerate(messages):
|
|
||||||
if m.role == Role.user.value:
|
|
||||||
color = "red"
|
|
||||||
elif m.role == Role.assistant.value:
|
|
||||||
color = "white"
|
|
||||||
elif m.role == Role.ipython.value:
|
|
||||||
color = "yellow"
|
|
||||||
elif m.role == Role.system.value:
|
|
||||||
color = "green"
|
|
||||||
else:
|
|
||||||
color = "white"
|
|
||||||
|
|
||||||
s = str(m)
|
|
||||||
cprint(f"{i} ::: {s[:100]}...", color=color)
|
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
@ -15,6 +15,8 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AgentSessionInfo(BaseModel):
|
class AgentSessionInfo(BaseModel):
|
||||||
session_id: str
|
session_id: str
|
||||||
|
@ -78,7 +80,7 @@ class AgentPersistence:
|
||||||
turn = Turn(**json.loads(value))
|
turn = Turn(**json.loads(value))
|
||||||
turns.append(turn)
|
turns.append(turn)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error parsing turn: {e}")
|
log.error(f"Error parsing turn: {e}")
|
||||||
continue
|
continue
|
||||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||||
return turns
|
return turns
|
||||||
|
|
|
@ -10,8 +10,6 @@ from jinja2 import Template
|
||||||
from llama_models.llama3.api import * # noqa: F403
|
from llama_models.llama3.api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
from termcolor import cprint # noqa: F401
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
DefaultMemoryQueryGeneratorConfig,
|
DefaultMemoryQueryGeneratorConfig,
|
||||||
LLMMemoryQueryGeneratorConfig,
|
LLMMemoryQueryGeneratorConfig,
|
||||||
|
@ -36,7 +34,6 @@ async def generate_rag_query(
|
||||||
query = await llm_rag_query_generator(config, messages, **kwargs)
|
query = await llm_rag_query_generator(config, messages, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
|
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
|
||||||
# cprint(f"Generated query >>>: {query}", color="green")
|
|
||||||
return query
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,14 +5,16 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import Message
|
from llama_models.llama3.api.datatypes import Message
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SafetyException(Exception): # noqa: N818
|
class SafetyException(Exception): # noqa: N818
|
||||||
def __init__(self, violation: SafetyViolation):
|
def __init__(self, violation: SafetyViolation):
|
||||||
|
@ -51,7 +53,4 @@ class ShieldRunnerMixin:
|
||||||
if violation.violation_level == ViolationLevel.ERROR:
|
if violation.violation_level == ViolationLevel.ERROR:
|
||||||
raise SafetyException(violation)
|
raise SafetyException(violation)
|
||||||
elif violation.violation_level == ViolationLevel.WARN:
|
elif violation.violation_level == ViolationLevel.WARN:
|
||||||
cprint(
|
log.warning(f"[Warn]{identifier} raised a warning")
|
||||||
f"[Warn]{identifier} raised a warning",
|
|
||||||
color="red",
|
|
||||||
)
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
|
@ -12,7 +13,6 @@ from abc import abstractmethod
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from .ipython_tool.code_execution import (
|
from .ipython_tool.code_execution import (
|
||||||
CodeExecutionContext,
|
CodeExecutionContext,
|
||||||
|
@ -27,6 +27,9 @@ from llama_stack.apis.agents import * # noqa: F403
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
|
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
|
||||||
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
||||||
if match:
|
if match:
|
||||||
|
@ -383,7 +386,7 @@ class CodeInterpreterTool(BaseTool):
|
||||||
if res_out != "":
|
if res_out != "":
|
||||||
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
|
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
|
||||||
if out_type == "stderr":
|
if out_type == "stderr":
|
||||||
cprint(f"ipython tool error: ↓\n{res_out}", color="red")
|
log.error(f"ipython tool error: ↓\n{res_out}")
|
||||||
|
|
||||||
message = ToolResponseMessage(
|
message = ToolResponseMessage(
|
||||||
call_id=tool_call.call_id,
|
call_id=tool_call.call_id,
|
||||||
|
|
|
@ -11,6 +11,7 @@ A custom Matplotlib backend that overrides the show method to return image bytes
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import json as _json
|
import json as _json
|
||||||
|
import logging
|
||||||
|
|
||||||
import matplotlib
|
import matplotlib
|
||||||
from matplotlib.backend_bases import FigureManagerBase
|
from matplotlib.backend_bases import FigureManagerBase
|
||||||
|
@ -18,6 +19,8 @@ from matplotlib.backend_bases import FigureManagerBase
|
||||||
# Import necessary components from Matplotlib
|
# Import necessary components from Matplotlib
|
||||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CustomFigureCanvas(FigureCanvasAgg):
|
class CustomFigureCanvas(FigureCanvasAgg):
|
||||||
def show(self):
|
def show(self):
|
||||||
|
@ -80,7 +83,7 @@ def show():
|
||||||
)
|
)
|
||||||
req_con.send_bytes(_json_dump.encode("utf-8"))
|
req_con.send_bytes(_json_dump.encode("utf-8"))
|
||||||
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
|
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
|
||||||
print(resp)
|
log.info(resp)
|
||||||
|
|
||||||
|
|
||||||
FigureCanvas = CustomFigureCanvas
|
FigureCanvas = CustomFigureCanvas
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
@ -31,7 +32,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
|
||||||
)
|
)
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
|
||||||
|
@ -50,6 +50,8 @@ from .config import (
|
||||||
MetaReferenceQuantizedInferenceConfig,
|
MetaReferenceQuantizedInferenceConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def model_checkpoint_dir(model) -> str:
|
def model_checkpoint_dir(model) -> str:
|
||||||
checkpoint_dir = Path(model_local_dir(model.descriptor()))
|
checkpoint_dir = Path(model_local_dir(model.descriptor()))
|
||||||
|
@ -185,7 +187,7 @@ class Llama:
|
||||||
model = Transformer(model_args)
|
model = Transformer(model_args)
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||||
return Llama(model, tokenizer, model_args, llama_model)
|
return Llama(model, tokenizer, model_args, llama_model)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -221,7 +223,7 @@ class Llama:
|
||||||
self.formatter.vision_token if t == 128256 else t
|
self.formatter.vision_token if t == 128256 else t
|
||||||
for t in model_input.tokens
|
for t in model_input.tokens
|
||||||
]
|
]
|
||||||
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
|
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
|
||||||
prompt_tokens = [model_input.tokens]
|
prompt_tokens = [model_input.tokens]
|
||||||
|
|
||||||
bsz = 1
|
bsz = 1
|
||||||
|
@ -231,9 +233,7 @@ class Llama:
|
||||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||||
|
|
||||||
if max_prompt_len >= params.max_seq_len:
|
if max_prompt_len >= params.max_seq_len:
|
||||||
cprint(
|
log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
|
||||||
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
from typing import AsyncGenerator, List
|
from typing import AsyncGenerator, List
|
||||||
|
|
||||||
|
@ -25,6 +26,7 @@ from .config import MetaReferenceInferenceConfig
|
||||||
from .generation import Llama
|
from .generation import Llama
|
||||||
from .model_parallel import LlamaModelParallelGenerator
|
from .model_parallel import LlamaModelParallelGenerator
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
# there's a single model parallel process running serving the model. for now,
|
# there's a single model parallel process running serving the model. for now,
|
||||||
# we don't support multiple concurrent requests to this process.
|
# we don't support multiple concurrent requests to this process.
|
||||||
SEMAPHORE = asyncio.Semaphore(1)
|
SEMAPHORE = asyncio.Semaphore(1)
|
||||||
|
@ -49,7 +51,7 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
||||||
# verify that the checkpoint actually is for this model lol
|
# verify that the checkpoint actually is for this model lol
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
print(f"Loading model `{self.model.descriptor()}`")
|
log.info(f"Loading model `{self.model.descriptor()}`")
|
||||||
if self.config.create_distributed_process_group:
|
if self.config.create_distributed_process_group:
|
||||||
self.generator = LlamaModelParallelGenerator(self.config)
|
self.generator = LlamaModelParallelGenerator(self.config)
|
||||||
self.generator.start()
|
self.generator.start()
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
@ -37,6 +38,8 @@ from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||||
|
|
||||||
from .generation import TokenResult
|
from .generation import TokenResult
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProcessingMessageName(str, Enum):
|
class ProcessingMessageName(str, Enum):
|
||||||
ready_request = "ready_request"
|
ready_request = "ready_request"
|
||||||
|
@ -183,16 +186,16 @@ def retrieve_requests(reply_socket_url: str):
|
||||||
group=get_model_parallel_group(),
|
group=get_model_parallel_group(),
|
||||||
)
|
)
|
||||||
if isinstance(updates[0], CancelSentinel):
|
if isinstance(updates[0], CancelSentinel):
|
||||||
print("quitting generation loop because request was cancelled")
|
log.info(
|
||||||
|
"quitting generation loop because request was cancelled"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
if mp_rank_0():
|
if mp_rank_0():
|
||||||
send_obj(EndSentinel())
|
send_obj(EndSentinel())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[debug] got exception {e}")
|
log.exception("exception in generation loop")
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
if mp_rank_0():
|
if mp_rank_0():
|
||||||
send_obj(ExceptionResponse(error=str(e)))
|
send_obj(ExceptionResponse(error=str(e)))
|
||||||
|
|
||||||
|
@ -252,7 +255,7 @@ def worker_process_entrypoint(
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
break
|
break
|
||||||
|
|
||||||
print("[debug] worker process done")
|
log.info("[debug] worker process done")
|
||||||
|
|
||||||
|
|
||||||
def launch_dist_group(
|
def launch_dist_group(
|
||||||
|
@ -313,7 +316,7 @@ def start_model_parallel_process(
|
||||||
|
|
||||||
request_socket.send(encode_msg(ReadyRequest()))
|
request_socket.send(encode_msg(ReadyRequest()))
|
||||||
response = request_socket.recv()
|
response = request_socket.recv()
|
||||||
print("Loaded model...")
|
log.info("Loaded model...")
|
||||||
|
|
||||||
return request_socket, process
|
return request_socket, process
|
||||||
|
|
||||||
|
@ -361,7 +364,7 @@ class ModelParallelProcessGroup:
|
||||||
break
|
break
|
||||||
|
|
||||||
if isinstance(obj, ExceptionResponse):
|
if isinstance(obj, ExceptionResponse):
|
||||||
print(f"[debug] got exception {obj.error}")
|
log.error(f"[debug] got exception {obj.error}")
|
||||||
raise Exception(obj.error)
|
raise Exception(obj.error)
|
||||||
|
|
||||||
if isinstance(obj, TaskResponse):
|
if isinstance(obj, TaskResponse):
|
||||||
|
|
|
@ -8,14 +8,20 @@
|
||||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||||
|
|
||||||
print("Using efficient FP8 operators in FBGEMM.")
|
log.info("Using efficient FP8 operators in FBGEMM.")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
log.error(
|
||||||
|
"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
@ -21,7 +22,6 @@ from llama_models.llama3.api.args import ModelArgs
|
||||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from termcolor import cprint
|
|
||||||
from torch import nn, Tensor
|
from torch import nn, Tensor
|
||||||
|
|
||||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||||
|
@ -30,6 +30,8 @@ from llama_stack.apis.inference import QuantizationType
|
||||||
|
|
||||||
from ..config import MetaReferenceQuantizedInferenceConfig
|
from ..config import MetaReferenceQuantizedInferenceConfig
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def swiglu_wrapper(
|
def swiglu_wrapper(
|
||||||
self,
|
self,
|
||||||
|
@ -60,7 +62,7 @@ def convert_to_fp8_quantized_model(
|
||||||
|
|
||||||
# Move weights to GPU with quantization
|
# Move weights to GPU with quantization
|
||||||
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
|
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
|
||||||
cprint("Loading fp8 scales...", "yellow")
|
log.info("Loading fp8 scales...")
|
||||||
fp8_scales_path = os.path.join(
|
fp8_scales_path = os.path.join(
|
||||||
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||||
)
|
)
|
||||||
|
@ -85,7 +87,7 @@ def convert_to_fp8_quantized_model(
|
||||||
fp8_activation_scale_ub,
|
fp8_activation_scale_ub,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cprint("Quantizing fp8 weights from bf16...", "yellow")
|
log.info("Quantizing fp8 weights from bf16...")
|
||||||
for block in model.layers:
|
for block in model.layers:
|
||||||
if isinstance(block, TransformerBlock):
|
if isinstance(block, TransformerBlock):
|
||||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
@ -32,6 +33,8 @@ from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impl
|
||||||
quantize_fp8,
|
quantize_fp8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
|
@ -102,7 +105,7 @@ def main(
|
||||||
else:
|
else:
|
||||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||||
|
|
||||||
print(ckpt_path)
|
log.info(ckpt_path)
|
||||||
assert (
|
assert (
|
||||||
quantized_ckpt_dir is not None
|
quantized_ckpt_dir is not None
|
||||||
), "QUantized checkpoint directory should not be None"
|
), "QUantized checkpoint directory should not be None"
|
||||||
|
|
|
@ -4,10 +4,18 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LogFormat(Enum):
|
||||||
|
TEXT = "text"
|
||||||
|
JSON = "json"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ConsoleConfig(BaseModel): ...
|
class ConsoleConfig(BaseModel):
|
||||||
|
log_format: LogFormat = LogFormat.JSON
|
||||||
|
|
|
@ -4,8 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from .config import LogFormat
|
||||||
|
|
||||||
from llama_stack.apis.telemetry import * # noqa: F403
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
from .config import ConsoleConfig
|
from .config import ConsoleConfig
|
||||||
|
|
||||||
|
@ -38,7 +41,11 @@ class ConsoleTelemetryImpl(Telemetry):
|
||||||
|
|
||||||
span_name = ".".join(names) if names else None
|
span_name = ".".join(names) if names else None
|
||||||
|
|
||||||
formatted = format_event(event, span_name)
|
if self.config.log_format == LogFormat.JSON:
|
||||||
|
formatted = format_event_json(event, span_name)
|
||||||
|
else:
|
||||||
|
formatted = format_event_text(event, span_name)
|
||||||
|
|
||||||
if formatted:
|
if formatted:
|
||||||
print(formatted)
|
print(formatted)
|
||||||
|
|
||||||
|
@ -69,7 +76,7 @@ SEVERITY_COLORS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def format_event(event: Event, span_name: str) -> Optional[str]:
|
def format_event_text(event: Event, span_name: str) -> Optional[str]:
|
||||||
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
|
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
|
||||||
span = ""
|
span = ""
|
||||||
if span_name:
|
if span_name:
|
||||||
|
@ -87,3 +94,23 @@ def format_event(event: Event, span_name: str) -> Optional[str]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return f"Unknown event type: {event}"
|
return f"Unknown event type: {event}"
|
||||||
|
|
||||||
|
|
||||||
|
def format_event_json(event: Event, span_name: str) -> Optional[str]:
    """Render a telemetry event as a single JSON document.

    Args:
        event: The event to render. Its ``timestamp``, ``trace_id`` and
            ``span_id`` attributes are read for every event.
        span_name: Dotted span name to attach to the record (may be ``None``).

    Returns:
        A JSON string for ``UnstructuredLogEvent`` instances, ``None`` for
        ``StructuredLogEvent`` instances (nothing is printed for those), and a
        JSON ``{"error": ...}`` document for any other event type.
    """
    # Common fields are collected first, before any type dispatch, so they are
    # read from the event regardless of which branch is taken below.
    record = dict(
        timestamp=event.timestamp.isoformat(),
        trace_id=event.trace_id,
        span_id=event.span_id,
        span_name=span_name,
    )

    if isinstance(event, UnstructuredLogEvent):
        record["type"] = "log"
        record["severity"] = event.severity.name
        record["message"] = event.message
        return json.dumps(record)

    if isinstance(event, StructuredLogEvent):
        return None

    return json.dumps({"error": f"Unknown event type: {event}"})
|
||||||
|
|
|
@ -4,16 +4,16 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from .config import CodeScannerConfig
|
from .config import CodeScannerConfig
|
||||||
|
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
||||||
"CodeScanner",
|
"CodeScanner",
|
||||||
"CodeShield",
|
"CodeShield",
|
||||||
|
@ -49,7 +49,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||||
from codeshield.cs import CodeShield
|
from codeshield.cs import CodeShield
|
||||||
|
|
||||||
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
|
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
|
||||||
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
|
# Log only a short preview (the first 50 characters) of the scanned text;
# slicing with [50:] dropped the preview and logged the remainder instead.
log.info(f"Running CodeScannerShield on {text[:50]}")
|
||||||
result = await CodeShield.scan_code(text)
|
result = await CodeShield.scan_code(text)
|
||||||
|
|
||||||
violation = None
|
violation = None
|
||||||
|
|
|
@ -4,10 +4,10 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
|
|
||||||
from .config import PromptGuardConfig, PromptGuardType
|
from .config import PromptGuardConfig, PromptGuardType
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||||
|
|
||||||
|
@ -93,9 +94,8 @@ class PromptGuardShield:
|
||||||
probabilities = torch.softmax(logits / self.temperature, dim=-1)
|
probabilities = torch.softmax(logits / self.temperature, dim=-1)
|
||||||
score_embedded = probabilities[0, 1].item()
|
score_embedded = probabilities[0, 1].item()
|
||||||
score_malicious = probabilities[0, 2].item()
|
score_malicious = probabilities[0, 2].item()
|
||||||
cprint(
|
log.info(
|
||||||
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
|
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
|
||||||
color="magenta",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
violation = None
|
violation = None
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -39,6 +40,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
request_has_media,
|
request_has_media,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
model_aliases = [
|
model_aliases = [
|
||||||
build_model_alias(
|
build_model_alias(
|
||||||
|
@ -105,7 +107,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
return AsyncClient(host=self.url)
|
return AsyncClient(host=self.url)
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
print(f"checking connectivity to Ollama at `{self.url}`...")
|
log.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||||
try:
|
try:
|
||||||
await self.client.ps()
|
await self.client.ps()
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
|
|
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class _HfAdapter(Inference, ModelsProtocolPrivate):
|
class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
|
@ -264,7 +264,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
class TGIAdapter(_HfAdapter):
|
class TGIAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: TGIImplConfig) -> None:
|
async def initialize(self, config: TGIImplConfig) -> None:
|
||||||
print(f"Initializing TGI client with url={config.url}")
|
log.info(f"Initializing TGI client with url={config.url}")
|
||||||
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
||||||
endpoint_info = await self.client.get_endpoint_info()
|
endpoint_info = await self.client.get_endpoint_info()
|
||||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||||
|
|
|
@ -3,6 +3,8 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
|
@ -34,6 +36,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
from .config import VLLMInferenceAdapterConfig
|
from .config import VLLMInferenceAdapterConfig
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def build_model_aliases():
|
def build_model_aliases():
|
||||||
return [
|
return [
|
||||||
build_model_alias(
|
build_model_alias(
|
||||||
|
@ -53,7 +58,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
self.client = None
|
self.client = None
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
print(f"Initializing VLLM client with base_url={self.config.url}")
|
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
||||||
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
@ -21,6 +22,8 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ChromaIndex(EmbeddingIndex):
|
class ChromaIndex(EmbeddingIndex):
|
||||||
def __init__(self, client: chromadb.AsyncHttpClient, collection):
|
def __init__(self, client: chromadb.AsyncHttpClient, collection):
|
||||||
|
@ -56,10 +59,7 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
doc = json.loads(doc)
|
doc = json.loads(doc)
|
||||||
chunk = Chunk(**doc)
|
chunk = Chunk(**doc)
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
log.exception(f"Failed to parse document: {doc}")
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
print(f"Failed to parse document: {doc}")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
@ -73,7 +73,7 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
|
|
||||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
def __init__(self, url: str) -> None:
|
def __init__(self, url: str) -> None:
|
||||||
print(f"Initializing ChromaMemoryAdapter with url: {url}")
|
log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||||
url = url.rstrip("/")
|
url = url.rstrip("/")
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
|
|
||||||
|
@ -88,12 +88,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
try:
|
try:
|
||||||
print(f"Connecting to Chroma server at: {self.host}:{self.port}")
|
log.info(f"Connecting to Chroma server at: {self.host}:{self.port}")
|
||||||
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
|
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
log.exception("Could not connect to Chroma server")
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
raise RuntimeError("Could not connect to Chroma server") from e
|
raise RuntimeError("Could not connect to Chroma server") from e
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
@ -123,10 +121,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
data = json.loads(collection.metadata["bank"])
|
data = json.loads(collection.metadata["bank"])
|
||||||
bank = parse_obj_as(VectorMemoryBank, data)
|
bank = parse_obj_as(VectorMemoryBank, data)
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
log.exception(f"Failed to parse bank: {collection.metadata}")
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
print(f"Failed to parse bank: {collection.metadata}")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
index = BankWithIndex(
|
index = BankWithIndex(
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
@ -24,6 +25,8 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import PGVectorConfig
|
from .config import PGVectorConfig
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def check_extension_version(cur):
|
def check_extension_version(cur):
|
||||||
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
|
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
|
||||||
|
@ -124,7 +127,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
print(f"Initializing PGVector memory adapter with config: {self.config}")
|
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||||
try:
|
try:
|
||||||
self.conn = psycopg2.connect(
|
self.conn = psycopg2.connect(
|
||||||
host=self.config.host,
|
host=self.config.host,
|
||||||
|
@ -138,7 +141,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
|
|
||||||
version = check_extension_version(self.cursor)
|
version = check_extension_version(self.cursor)
|
||||||
if version:
|
if version:
|
||||||
print(f"Vector extension version: {version}")
|
log.info(f"Vector extension version: {version}")
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Vector extension is not installed.")
|
raise RuntimeError("Vector extension is not installed.")
|
||||||
|
|
||||||
|
@ -151,9 +154,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
log.exception("Could not connect to PGVector database server")
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
raise RuntimeError("Could not connect to PGVector database server") from e
|
raise RuntimeError("Could not connect to PGVector database server") from e
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import traceback
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
CHUNK_ID_KEY = "_chunk_id"
|
CHUNK_ID_KEY = "_chunk_id"
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,7 +91,7 @@ class QdrantIndex(EmbeddingIndex):
|
||||||
try:
|
try:
|
||||||
chunk = Chunk(**point.payload["chunk_content"])
|
chunk = Chunk(**point.payload["chunk_content"])
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
log.exception("Failed to parse chunk")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
@ -22,6 +23,8 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import WeaviateConfig, WeaviateRequestProviderData
|
from .config import WeaviateConfig, WeaviateRequestProviderData
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WeaviateIndex(EmbeddingIndex):
|
class WeaviateIndex(EmbeddingIndex):
|
||||||
def __init__(self, client: weaviate.Client, collection_name: str):
|
def __init__(self, client: weaviate.Client, collection_name: str):
|
||||||
|
@ -69,10 +72,7 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
chunk_dict = json.loads(chunk_json)
|
chunk_dict = json.loads(chunk_json)
|
||||||
chunk = Chunk(**chunk_dict)
|
chunk = Chunk(**chunk_dict)
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
log.exception(f"Failed to parse document: {chunk_json}")
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
print(f"Failed to parse document: {chunk_json}")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
|
@ -7,14 +7,13 @@
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from PIL import Image as PIL_Image
|
from PIL import Image as PIL_Image
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_models.datatypes import ModelFamily
|
from llama_models.datatypes import ModelFamily
|
||||||
|
@ -29,6 +28,8 @@ from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference import supported_inference_models
|
from llama_stack.providers.utils.inference import supported_inference_models
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def content_has_media(content: InterleavedTextMedia):
|
def content_has_media(content: InterleavedTextMedia):
|
||||||
def _has_media_content(c):
|
def _has_media_content(c):
|
||||||
|
@ -175,13 +176,13 @@ def chat_completion_request_to_messages(
|
||||||
"""
|
"""
|
||||||
model = resolve_model(llama_model)
|
model = resolve_model(llama_model)
|
||||||
if model is None:
|
if model is None:
|
||||||
cprint(f"Could not resolve model {llama_model}", color="red")
|
log.error(f"Could not resolve model {llama_model}")
|
||||||
return request.messages
|
return request.messages
|
||||||
|
|
||||||
allowed_models = supported_inference_models()
|
allowed_models = supported_inference_models()
|
||||||
descriptors = [m.descriptor() for m in allowed_models]
|
descriptors = [m.descriptor() for m in allowed_models]
|
||||||
if model.descriptor() not in descriptors:
|
if model.descriptor() not in descriptors:
|
||||||
cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
|
log.error(f"Unsupported inference model? {model.descriptor()}")
|
||||||
return request.messages
|
return request.messages
|
||||||
|
|
||||||
if model.model_family == ModelFamily.llama3_1 or (
|
if model.model_family == ModelFamily.llama3_1 or (
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
@ -13,6 +14,8 @@ from psycopg2.extras import DictCursor
|
||||||
from ..api import KVStore
|
from ..api import KVStore
|
||||||
from ..config import PostgresKVStoreConfig
|
from ..config import PostgresKVStoreConfig
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PostgresKVStoreImpl(KVStore):
|
class PostgresKVStoreImpl(KVStore):
|
||||||
def __init__(self, config: PostgresKVStoreConfig):
|
def __init__(self, config: PostgresKVStoreConfig):
|
||||||
|
@ -43,9 +46,8 @@ class PostgresKVStoreImpl(KVStore):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
log.exception("Could not connect to PostgreSQL database server")
|
||||||
raise RuntimeError("Could not connect to PostgreSQL database server") from e
|
raise RuntimeError("Could not connect to PostgreSQL database server") from e
|
||||||
|
|
||||||
def _namespaced_key(self, key: str) -> str:
|
def _namespaced_key(self, key: str) -> str:
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
@ -16,13 +17,14 @@ import httpx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from pypdf import PdfReader
|
from pypdf import PdfReader
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MINILM_L6_V2_DIMENSION = 384
|
ALL_MINILM_L6_V2_DIMENSION = 384
|
||||||
|
|
||||||
EMBEDDING_MODELS = {}
|
EMBEDDING_MODELS = {}
|
||||||
|
@ -35,7 +37,7 @@ def get_embedding_model(model: str) -> "SentenceTransformer":
|
||||||
if loaded_model is not None:
|
if loaded_model is not None:
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
||||||
print(f"Loading sentence transformer for {model}...")
|
log.info(f"Loading sentence transformer for {model}...")
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
loaded_model = SentenceTransformer(model)
|
loaded_model = SentenceTransformer(model)
|
||||||
|
@ -92,7 +94,7 @@ def content_from_data(data_url: str) -> str:
|
||||||
return "\n".join([page.extract_text() for page in pdf_reader.pages])
|
return "\n".join([page.extract_text() for page in pdf_reader.pages])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
cprint("Could not extract content from data_url properly.", color="red")
|
log.error("Could not extract content from data_url properly.")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,8 @@ from typing import Any, Callable, Dict, List
|
||||||
|
|
||||||
from llama_stack.apis.telemetry import * # noqa: F403
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def generate_short_uuid(len: int = 12):
|
def generate_short_uuid(len: int = 12):
|
||||||
full_uuid = uuid.uuid4()
|
full_uuid = uuid.uuid4()
|
||||||
|
@ -40,7 +42,7 @@ class BackgroundLogger:
|
||||||
try:
|
try:
|
||||||
self.log_queue.put_nowait(event)
|
self.log_queue.put_nowait(event)
|
||||||
except queue.Full:
|
except queue.Full:
|
||||||
print("Log queue is full, dropping event")
|
log.error("Log queue is full, dropping event")
|
||||||
|
|
||||||
def _process_logs(self):
|
def _process_logs(self):
|
||||||
while True:
|
while True:
|
||||||
|
@ -125,7 +127,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None):
|
||||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||||
|
|
||||||
if BACKGROUND_LOGGER is None:
|
if BACKGROUND_LOGGER is None:
|
||||||
print("No Telemetry implementation set. Skipping trace initialization...")
|
log.info("No Telemetry implementation set. Skipping trace initialization...")
|
||||||
return
|
return
|
||||||
|
|
||||||
trace_id = generate_short_uuid()
|
trace_id = generate_short_uuid()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue