Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 18:50:44 +00:00)
fix: Run prompt_guard model in a separate thread
The GPU model usage blocks the CPU. Move it to its own thread. Also wrap it in a lock to prevent multiple simultaneous runs from exhausting the GPU.

Closes: #1746
Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent 18bac27d4e
commit 6434cdfdab
1 changed file with 7 additions and 1 deletion
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import logging
 from typing import Any, Dict, List
 
@@ -29,6 +30,10 @@ from .config import PromptGuardConfig, PromptGuardType
 log = logging.getLogger(__name__)
 
 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
+# We are using a thread to prevent the model usage from blocking the event loop.
+# But to ensure multiple model runs don't exhaust the GPU memory, we only run
+# 1 at a time.
+MODEL_LOCK = asyncio.Lock()
 
 
 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
@@ -89,7 +94,8 @@ class PromptGuardShield:
         inputs = self.tokenizer(text, return_tensors="pt")
         inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
         with torch.no_grad():
-            outputs = self.model(**inputs)
+            async with MODEL_LOCK:
+                outputs = await asyncio.to_thread(self.model, **inputs)
         logits = outputs[0]
         probabilities = torch.softmax(logits / self.temperature, dim=-1)
         score_embedded = probabilities[0, 1].item()
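For reference, below is a minimal, self-contained sketch of the pattern the commit message describes, using only the standard library: a module-level asyncio.Lock serializes model runs, and asyncio.to_thread moves the blocking call off the event loop. The names blocking_model_call and run_shield are hypothetical stand-ins for illustration, not llama-stack APIs.

# Illustrative sketch (not the llama-stack code itself): offload a blocking,
# GPU-bound call to a worker thread, and serialize runs with a single lock so
# concurrent requests cannot pile up on the GPU.
import asyncio
import time

# One lock shared by all requests: only one model run at a time.
MODEL_LOCK = asyncio.Lock()


def blocking_model_call(text: str) -> str:
    """Stand-in for a GPU inference call that blocks the calling thread."""
    time.sleep(0.5)  # simulates time spent on the GPU
    return f"score for {text!r}"


async def run_shield(text: str) -> str:
    # The event loop stays responsive: the blocking call runs in a thread,
    # while the lock ensures runs happen one at a time rather than overlapping.
    async with MODEL_LOCK:
        return await asyncio.to_thread(blocking_model_call, text)


async def main() -> None:
    # Ten concurrent requests complete one after another (~5 s total),
    # but the event loop is never blocked while each one runs.
    results = await asyncio.gather(*(run_shield(f"msg {i}") for i in range(10)))
    print(results)


if __name__ == "__main__":
    asyncio.run(main())

asyncio.to_thread is available from Python 3.9 onward; on older versions the same effect comes from loop.run_in_executor. Note that it is the lock, not the thread itself, that keeps simultaneous requests from running the model on the GPU at the same time.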