diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index fce3e3d14..37a0debac 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import logging
 from typing import Any, Dict, List
 
@@ -29,6 +30,10 @@ from .config import PromptGuardConfig, PromptGuardType
 log = logging.getLogger(__name__)
 
 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
+# We are using a thread to prevent the model usage from blocking the event loop
+# But to ensure multiple model runs don't exhaust the GPU memory we only run
+# 1 at a time.
+MODEL_LOCK = asyncio.Lock()
 
 
 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
@@ -89,7 +94,8 @@ class PromptGuardShield:
         inputs = self.tokenizer(text, return_tensors="pt")
         inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
         with torch.no_grad():
-            outputs = self.model(**inputs)
+            async with MODEL_LOCK:
+                outputs = await asyncio.to_thread(self.model, **inputs)
         logits = outputs[0]
         probabilities = torch.softmax(logits / self.temperature, dim=-1)
         score_embedded = probabilities[0, 1].item()
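
For context, here is a minimal, self-contained sketch of the concurrency pattern this patch introduces. It is not part of llama_stack: `fake_model_forward` is a hypothetical stand-in for the blocking `self.model(**inputs)` call, and the shield/request names are illustrative. It shows how a module-level `asyncio.Lock` combined with `asyncio.to_thread` (Python 3.9+) serializes model runs while keeping the event loop free.

```python
import asyncio
import time

# One lock for the whole module: at most one "forward pass" runs at a time,
# mirroring MODEL_LOCK in the patch.
MODEL_LOCK = asyncio.Lock()


def fake_model_forward(text: str) -> str:
    """Stand-in for the blocking, GPU-bound self.model(**inputs) call."""
    time.sleep(0.5)  # simulate a slow forward pass
    return f"score({text})"


async def run_shield(text: str) -> str:
    # Serialize model runs so concurrent requests cannot exhaust GPU memory,
    # but execute the blocking call in a worker thread so the event loop
    # stays responsive while we wait.
    async with MODEL_LOCK:
        return await asyncio.to_thread(fake_model_forward, text)


async def main() -> None:
    # Three concurrent requests: they acquire MODEL_LOCK one at a time,
    # yet the loop itself is never blocked while each inference runs.
    print(await asyncio.gather(*(run_shield(t) for t in ["a", "b", "c"])))


if __name__ == "__main__":
    asyncio.run(main())
```

The lock guarantees only one forward pass holds the GPU at a time, while `asyncio.to_thread` moves the blocking call off the event loop so other coroutines continue to be scheduled during inference.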