Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 10:13:05 +00:00)
fix: Run prompt_guard model in a separate thread
The GPU model usage blocks the CPU. Move it to its own thread. Also wrap it in a lock to prevent multiple simultaneous runs from exhausting the GPU.

Closes: #1746
Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent 18bac27d4e
commit 6434cdfdab

1 changed file with 7 additions and 1 deletion
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import logging
 from typing import Any, Dict, List
 
@@ -29,6 +30,10 @@ from .config import PromptGuardConfig, PromptGuardType
 log = logging.getLogger(__name__)
 
 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
+# We are using a thread to prevent the model usage from blocking the event loop
+# But to ensure multiple model runs don't exhaust the GPU memory we only run
+# 1 at a time.
+MODEL_LOCK = asyncio.Lock()
 
 
 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
@@ -89,7 +94,8 @@ class PromptGuardShield:
         inputs = self.tokenizer(text, return_tensors="pt")
         inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
         with torch.no_grad():
-            outputs = self.model(**inputs)
+            async with MODEL_LOCK:
+                outputs = await asyncio.to_thread(self.model, **inputs)
         logits = outputs[0]
         probabilities = torch.softmax(logits / self.temperature, dim=-1)
         score_embedded = probabilities[0, 1].item()
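For illustration, here is a minimal standalone sketch of the pattern this commit applies, written against plain asyncio rather than the llama-stack code; the names classify_blocking and classify and the simulated delay are hypothetical stand-ins for the Prompt Guard forward pass.

import asyncio
import time

# Only one model run at a time, mirroring MODEL_LOCK in the diff.
MODEL_LOCK = asyncio.Lock()


def classify_blocking(text: str) -> float:
    # Hypothetical stand-in for the GPU-bound forward pass
    # (self.model(**inputs) in the diff).
    time.sleep(0.5)  # simulate the blocking inference call
    return 0.0


async def classify(text: str) -> float:
    # Hold the lock across the await so concurrent callers queue up instead of
    # launching simultaneous runs; to_thread keeps the event loop unblocked.
    async with MODEL_LOCK:
        return await asyncio.to_thread(classify_blocking, text)


async def main() -> None:
    # Three concurrent shield calls: they serialize on MODEL_LOCK, but the
    # event loop stays free to service other tasks while each run executes.
    scores = await asyncio.gather(*(classify(f"prompt {i}") for i in range(3)))
    print(scores)


if __name__ == "__main__":
    asyncio.run(main())

Note that asyncio.to_thread only buys concurrency while the worker thread releases the GIL (as PyTorch's native ops generally do); it is the lock that bounds GPU memory use to a single in-flight run.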