forked from phoenix-oss/llama-stack-mirror
use logging instead of prints (#499)
# What does this PR do? This PR moves all print statements to use logging. Things changed: - Had to add `await start_trace("sse_generator")` to server.py to actually get tracing working. else was not seeing any logs - If no telemetry provider is provided in the run.yaml, we will write to stdout - by default, the logs are going to be in JSON, but we expose an option to configure to output in a human readable way.
This commit is contained in:
parent
4e1105e563
commit
6395dadc2b
36 changed files with 234 additions and 163 deletions
|
@ -6,6 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
|
@ -19,7 +20,6 @@ from urllib.parse import urlparse
|
|||
|
||||
import httpx
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
@ -43,6 +43,8 @@ from .tools.builtin import (
|
|||
)
|
||||
from .tools.safety import SafeTool
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(
|
||||
|
@ -137,7 +139,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
# print_dialog(messages)
|
||||
return messages
|
||||
|
||||
async def create_session(self, name: str) -> str:
|
||||
|
@ -185,10 +186,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
stream=request.stream,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
cprint(
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
"white",
|
||||
attrs=["bold"],
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
@ -407,7 +406,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
|
||||
else:
|
||||
msg_str = str(msg)
|
||||
cprint(f"{msg_str}", color=color)
|
||||
log.info(f"{msg_str}")
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -506,12 +505,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
cprint("Done with MAX iterations, exiting.")
|
||||
log.info("Done with MAX iterations, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
if stop_reason == StopReason.out_of_tokens:
|
||||
cprint("Out of token budget, exiting.")
|
||||
log.info("Out of token budget, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
|
@ -525,10 +524,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
message.content = [message.content] + attachments
|
||||
yield message
|
||||
else:
|
||||
cprint(f"Partial message: {str(message)}", color="green")
|
||||
log.info(f"Partial message: {str(message)}", color="green")
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
cprint(f"{str(message)}", color="green")
|
||||
log.info(f"{str(message)}", color="green")
|
||||
try:
|
||||
tool_call = message.tool_calls[0]
|
||||
|
||||
|
@ -740,9 +739,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for c in chunks[: memory.max_chunks]:
|
||||
tokens += c.token_count
|
||||
if tokens > memory.max_tokens_in_context:
|
||||
cprint(
|
||||
log.error(
|
||||
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
||||
"red",
|
||||
)
|
||||
break
|
||||
picked.append(f"id:{c.document_id}; content:{c.content}")
|
||||
|
@ -786,7 +784,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
print(f"Downloading {url} -> {filepath}")
|
||||
log.info(f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
|
@ -826,20 +824,3 @@ async def execute_tool_call_maybe(
|
|||
tool = tools_dict[name]
|
||||
result_messages = await tool.run(messages)
|
||||
return result_messages
|
||||
|
||||
|
||||
def print_dialog(messages: List[Message]):
|
||||
for i, m in enumerate(messages):
|
||||
if m.role == Role.user.value:
|
||||
color = "red"
|
||||
elif m.role == Role.assistant.value:
|
||||
color = "white"
|
||||
elif m.role == Role.ipython.value:
|
||||
color = "yellow"
|
||||
elif m.role == Role.system.value:
|
||||
color = "green"
|
||||
else:
|
||||
color = "white"
|
||||
|
||||
s = str(m)
|
||||
cprint(f"{i} ::: {s[:100]}...", color=color)
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
|
@ -15,6 +15,8 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentSessionInfo(BaseModel):
|
||||
session_id: str
|
||||
|
@ -78,7 +80,7 @@ class AgentPersistence:
|
|||
turn = Turn(**json.loads(value))
|
||||
turns.append(turn)
|
||||
except Exception as e:
|
||||
print(f"Error parsing turn: {e}")
|
||||
log.error(f"Error parsing turn: {e}")
|
||||
continue
|
||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||
return turns
|
||||
|
|
|
@ -10,8 +10,6 @@ from jinja2 import Template
|
|||
from llama_models.llama3.api import * # noqa: F403
|
||||
|
||||
|
||||
from termcolor import cprint # noqa: F401
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
|
@ -36,7 +34,6 @@ async def generate_rag_query(
|
|||
query = await llm_rag_query_generator(config, messages, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
|
||||
# cprint(f"Generated query >>>: {query}", color="green")
|
||||
return query
|
||||
|
||||
|
||||
|
|
|
@ -5,14 +5,16 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
def __init__(self, violation: SafetyViolation):
|
||||
|
@ -51,7 +53,4 @@ class ShieldRunnerMixin:
|
|||
if violation.violation_level == ViolationLevel.ERROR:
|
||||
raise SafetyException(violation)
|
||||
elif violation.violation_level == ViolationLevel.WARN:
|
||||
cprint(
|
||||
f"[Warn]{identifier} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
log.warning(f"[Warn]{identifier} raised a warning")
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
|
@ -12,7 +13,6 @@ from abc import abstractmethod
|
|||
from typing import List, Optional
|
||||
|
||||
import requests
|
||||
from termcolor import cprint
|
||||
|
||||
from .ipython_tool.code_execution import (
|
||||
CodeExecutionContext,
|
||||
|
@ -27,6 +27,9 @@ from llama_stack.apis.agents import * # noqa: F403
|
|||
from .base import BaseTool
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
|
||||
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
||||
if match:
|
||||
|
@ -383,7 +386,7 @@ class CodeInterpreterTool(BaseTool):
|
|||
if res_out != "":
|
||||
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
|
||||
if out_type == "stderr":
|
||||
cprint(f"ipython tool error: ↓\n{res_out}", color="red")
|
||||
log.error(f"ipython tool error: ↓\n{res_out}")
|
||||
|
||||
message = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
|
|
|
@ -11,6 +11,7 @@ A custom Matplotlib backend that overrides the show method to return image bytes
|
|||
import base64
|
||||
import io
|
||||
import json as _json
|
||||
import logging
|
||||
|
||||
import matplotlib
|
||||
from matplotlib.backend_bases import FigureManagerBase
|
||||
|
@ -18,6 +19,8 @@ from matplotlib.backend_bases import FigureManagerBase
|
|||
# Import necessary components from Matplotlib
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomFigureCanvas(FigureCanvasAgg):
|
||||
def show(self):
|
||||
|
@ -80,7 +83,7 @@ def show():
|
|||
)
|
||||
req_con.send_bytes(_json_dump.encode("utf-8"))
|
||||
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
|
||||
print(resp)
|
||||
log.info(resp)
|
||||
|
||||
|
||||
FigureCanvas = CustomFigureCanvas
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
|
@ -31,7 +32,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
|
|||
)
|
||||
from llama_models.sku_list import resolve_model
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
|
@ -50,6 +50,8 @@ from .config import (
|
|||
MetaReferenceQuantizedInferenceConfig,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model.descriptor()))
|
||||
|
@ -185,7 +187,7 @@ class Llama:
|
|||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
return Llama(model, tokenizer, model_args, llama_model)
|
||||
|
||||
def __init__(
|
||||
|
@ -221,7 +223,7 @@ class Llama:
|
|||
self.formatter.vision_token if t == 128256 else t
|
||||
for t in model_input.tokens
|
||||
]
|
||||
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
|
||||
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
|
||||
prompt_tokens = [model_input.tokens]
|
||||
|
||||
bsz = 1
|
||||
|
@ -231,9 +233,7 @@ class Llama:
|
|||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
cprint(
|
||||
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
|
||||
)
|
||||
log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
|
@ -25,6 +26,7 @@ from .config import MetaReferenceInferenceConfig
|
|||
from .generation import Llama
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
# there's a single model parallel process running serving the model. for now,
|
||||
# we don't support multiple concurrent requests to this process.
|
||||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
@ -49,7 +51,7 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
|||
# verify that the checkpoint actually is for this model lol
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print(f"Loading model `{self.model.descriptor()}`")
|
||||
log.info(f"Loading model `{self.model.descriptor()}`")
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(self.config)
|
||||
self.generator.start()
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import tempfile
|
||||
|
@ -37,6 +38,8 @@ from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
|||
|
||||
from .generation import TokenResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProcessingMessageName(str, Enum):
|
||||
ready_request = "ready_request"
|
||||
|
@ -183,16 +186,16 @@ def retrieve_requests(reply_socket_url: str):
|
|||
group=get_model_parallel_group(),
|
||||
)
|
||||
if isinstance(updates[0], CancelSentinel):
|
||||
print("quitting generation loop because request was cancelled")
|
||||
log.info(
|
||||
"quitting generation loop because request was cancelled"
|
||||
)
|
||||
break
|
||||
|
||||
if mp_rank_0():
|
||||
send_obj(EndSentinel())
|
||||
except Exception as e:
|
||||
print(f"[debug] got exception {e}")
|
||||
import traceback
|
||||
log.exception("exception in generation loop")
|
||||
|
||||
traceback.print_exc()
|
||||
if mp_rank_0():
|
||||
send_obj(ExceptionResponse(error=str(e)))
|
||||
|
||||
|
@ -252,7 +255,7 @@ def worker_process_entrypoint(
|
|||
except StopIteration:
|
||||
break
|
||||
|
||||
print("[debug] worker process done")
|
||||
log.info("[debug] worker process done")
|
||||
|
||||
|
||||
def launch_dist_group(
|
||||
|
@ -313,7 +316,7 @@ def start_model_parallel_process(
|
|||
|
||||
request_socket.send(encode_msg(ReadyRequest()))
|
||||
response = request_socket.recv()
|
||||
print("Loaded model...")
|
||||
log.info("Loaded model...")
|
||||
|
||||
return request_socket, process
|
||||
|
||||
|
@ -361,7 +364,7 @@ class ModelParallelProcessGroup:
|
|||
break
|
||||
|
||||
if isinstance(obj, ExceptionResponse):
|
||||
print(f"[debug] got exception {obj.error}")
|
||||
log.error(f"[debug] got exception {obj.error}")
|
||||
raise Exception(obj.error)
|
||||
|
||||
if isinstance(obj, TaskResponse):
|
||||
|
|
|
@ -8,14 +8,20 @@
|
|||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import collections
|
||||
|
||||
import logging
|
||||
from typing import Optional, Type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
print("Using efficient FP8 operators in FBGEMM.")
|
||||
log.info("Using efficient FP8 operators in FBGEMM.")
|
||||
except ImportError:
|
||||
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
||||
log.error(
|
||||
"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
|
||||
)
|
||||
raise
|
||||
|
||||
import torch
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
@ -21,7 +22,6 @@ from llama_models.llama3.api.args import ModelArgs
|
|||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from termcolor import cprint
|
||||
from torch import nn, Tensor
|
||||
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
@ -30,6 +30,8 @@ from llama_stack.apis.inference import QuantizationType
|
|||
|
||||
from ..config import MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def swiglu_wrapper(
|
||||
self,
|
||||
|
@ -60,7 +62,7 @@ def convert_to_fp8_quantized_model(
|
|||
|
||||
# Move weights to GPU with quantization
|
||||
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
|
||||
cprint("Loading fp8 scales...", "yellow")
|
||||
log.info("Loading fp8 scales...")
|
||||
fp8_scales_path = os.path.join(
|
||||
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||
)
|
||||
|
@ -85,7 +87,7 @@ def convert_to_fp8_quantized_model(
|
|||
fp8_activation_scale_ub,
|
||||
)
|
||||
else:
|
||||
cprint("Quantizing fp8 weights from bf16...", "yellow")
|
||||
log.info("Quantizing fp8 weights from bf16...")
|
||||
for block in model.layers:
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
@ -32,6 +33,8 @@ from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impl
|
|||
quantize_fp8,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main(
|
||||
ckpt_dir: str,
|
||||
|
@ -102,7 +105,7 @@ def main(
|
|||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
print(ckpt_path)
|
||||
log.info(ckpt_path)
|
||||
assert (
|
||||
quantized_ckpt_dir is not None
|
||||
), "QUantized checkpoint directory should not be None"
|
||||
|
|
|
@ -4,10 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LogFormat(Enum):
|
||||
TEXT = "text"
|
||||
JSON = "json"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConsoleConfig(BaseModel): ...
|
||||
class ConsoleConfig(BaseModel):
|
||||
log_format: LogFormat = LogFormat.JSON
|
||||
|
|
|
@ -4,8 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from .config import LogFormat
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
from .config import ConsoleConfig
|
||||
|
||||
|
@ -38,7 +41,11 @@ class ConsoleTelemetryImpl(Telemetry):
|
|||
|
||||
span_name = ".".join(names) if names else None
|
||||
|
||||
formatted = format_event(event, span_name)
|
||||
if self.config.log_format == LogFormat.JSON:
|
||||
formatted = format_event_json(event, span_name)
|
||||
else:
|
||||
formatted = format_event_text(event, span_name)
|
||||
|
||||
if formatted:
|
||||
print(formatted)
|
||||
|
||||
|
@ -69,7 +76,7 @@ SEVERITY_COLORS = {
|
|||
}
|
||||
|
||||
|
||||
def format_event(event: Event, span_name: str) -> Optional[str]:
|
||||
def format_event_text(event: Event, span_name: str) -> Optional[str]:
|
||||
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
|
||||
span = ""
|
||||
if span_name:
|
||||
|
@ -87,3 +94,23 @@ def format_event(event: Event, span_name: str) -> Optional[str]:
|
|||
return None
|
||||
|
||||
return f"Unknown event type: {event}"
|
||||
|
||||
|
||||
def format_event_json(event: Event, span_name: str) -> Optional[str]:
|
||||
base_data = {
|
||||
"timestamp": event.timestamp.isoformat(),
|
||||
"trace_id": event.trace_id,
|
||||
"span_id": event.span_id,
|
||||
"span_name": span_name,
|
||||
}
|
||||
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
base_data.update(
|
||||
{"type": "log", "severity": event.severity.name, "message": event.message}
|
||||
)
|
||||
return json.dumps(base_data)
|
||||
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
return None
|
||||
|
||||
return json.dumps({"error": f"Unknown event type: {event}"})
|
||||
|
|
|
@ -4,16 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
||||
from termcolor import cprint
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
||||
"CodeScanner",
|
||||
"CodeShield",
|
||||
|
@ -49,7 +49,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
from codeshield.cs import CodeShield
|
||||
|
||||
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
|
||||
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
|
||||
log.info(f"Running CodeScannerShield on {text[50:]}")
|
||||
result = await CodeShield.scan_code(text)
|
||||
|
||||
violation = None
|
||||
|
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from termcolor import cprint
|
||||
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
|
@ -20,6 +20,7 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
|||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
|
||||
|
@ -93,9 +94,8 @@ class PromptGuardShield:
|
|||
probabilities = torch.softmax(logits / self.temperature, dim=-1)
|
||||
score_embedded = probabilities[0, 1].item()
|
||||
score_malicious = probabilities[0, 2].item()
|
||||
cprint(
|
||||
log.info(
|
||||
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
|
||||
color="magenta",
|
||||
)
|
||||
|
||||
violation = None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue