Fixed breaking tests and ran pre-commit

Chantal D Gama Rose 2025-02-21 07:19:40 +00:00
parent 78b1105f5d
commit 66726241aa
8 changed files with 34 additions and 81 deletions

@@ -20,8 +20,10 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
### Models
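A minimal sketch of how `${env.VAR:default}` placeholders like the ones these variables feed can be resolved; the helper below is invented for illustration and is not Llama Stack's actual resolver:

```python
import os
import re

# Hypothetical helper (not Llama Stack's real implementation): resolves the
# "${env.VAR:default}" placeholder syntax used throughout run.yaml.
_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z0-9_]+)(?::([^}]*))?\}")

def resolve_env_placeholder(value: str) -> str:
    """Replace ${env.NAME:default} with os.environ["NAME"], else the default."""
    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2) or ""
        return os.environ.get(name, default)
    return _PLACEHOLDER.sub(_sub, value)

# Unset variables fall back to their inline defaults:
print(resolve_env_placeholder("${env.GUARDRAILS_SERVICE_URL:http://0.0.0.0:7331}"))
# -> http://0.0.0.0:7331 unless GUARDRAILS_SERVICE_URL is exported
```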

@@ -35,18 +35,13 @@ class NVIDIASafetyConfig(BaseModel):
By default the configuration will attempt to read the NVIDIA_API_KEY environment
variable to set the api_key. Please do not put your API key in code.
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
)
config_id: Optional[str] = Field(
default="self-check",
description="Config ID to use from the config store"
)
config_store_path: Optional[str] = Field(
default="/config-store",
description="Path to config store"
)
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
config_store_path: Optional[str] = Field(default="/config-store", description="Path to config store")
@classmethod
@field_validator("guard_type")
@@ -54,10 +49,10 @@ class NVIDIASafetyConfig(BaseModel):
if v not in [t.value for t in ShieldType]:
raise ValueError(f"Unknown shield type: {v}")
return v
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "self-check"
"config_id": "self-check",
}
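Taken together, the reformatted config behaves as before; here is a self-contained sketch (abridged from the diff above, validators omitted, assuming pydantic v2) showing the defaults it produces:

```python
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field

# Abridged from the diff above (validators omitted); assumes pydantic v2.
class NVIDIASafetyConfig(BaseModel):
    guardrails_service_url: str = Field(
        default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
        description="The url for accessing the guardrails service",
    )
    config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
    config_store_path: Optional[str] = Field(default="/config-store", description="Path to config store")

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
            "config_id": "self-check",
        }

config = NVIDIASafetyConfig()
print(config.guardrails_service_url)   # http://0.0.0.0:7331 unless NVIDIA_BASE_URL is set
print(NVIDIASafetyConfig.sample_run_config())
```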

@@ -25,17 +25,6 @@ from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
SHIELD_IDS_TO_MODEL_MAPPING = {
CoreModelId.llama3_8b_instruct.value: "meta/llama3-8b-instruct",
CoreModelId.llama3_70b_instruct.value: "meta/llama3-70b-instruct",
CoreModelId.llama3_1_8b_instruct.value: "meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_70b_instruct.value: "meta/llama-3.1-70b-instruct",
CoreModelId.llama3_1_405b_instruct.value: "meta/llama-3.1-405b-instruct",
CoreModelId.llama3_2_1b_instruct.value: "meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_3b_instruct.value: "meta/llama-3.2-3b-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value: "meta/llama-3.2-11b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value: "meta/llama-3.2-90b-vision-instruct"
}
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
@@ -59,11 +48,11 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
# print(shield.provider_shield_id)
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
class NeMoGuardrails:
def __init__(
self,
@@ -75,7 +64,9 @@ class NeMoGuardrails:
self.config_id = config.config_id
self.config_store_path = config.config_store_path
self.model = model
assert self.config_id is not None or self.config_store_path is not None, "Must provide one of config id or config store path"
assert self.config_id is not None or self.config_store_path is not None, (
"Must provide one of config id or config store path"
)
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
@@ -99,21 +90,19 @@ class NeMoGuardrails:
"stream": False,
"guardrails": {
"config_id": self.config_id,
}
},
}
response = requests.post(
url=f"{self.guardrails_service_url}/v1/guardrail/checks",
headers=headers,
json=request_data
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
)
print(response)
response.raise_for_status()
if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'):
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
response_json = response.json()
if response_json["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response_json["rails_status"]
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
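For context, a standalone sketch of the HTTP call `NeMoGuardrails.run()` makes; the endpoint, payload keys, and the "blocked" status are taken from the diff above, while the headers and model id are assumptions:

```python
import requests

def check_messages(guardrails_service_url: str, config_id: str, messages: list) -> bool:
    """Return True if the guardrails service reports the conversation as blocked.

    Sketch only: payload keys mirror the diff above; the model id and
    headers are placeholders, not confirmed values.
    """
    request_data = {
        "model": "meta/llama-3.1-8b-instruct",  # assumed; run() passes the shield's model
        "messages": messages,
        "stream": False,
        "guardrails": {
            "config_id": config_id,
        },
    }
    response = requests.post(
        url=f"{guardrails_service_url}/v1/guardrail/checks",
        headers={"Accept": "application/json"},
        json=request_data,
    )
    response.raise_for_status()
    return response.json().get("status") == "blocked"

# Example (requires a running NeMo Guardrails service):
# check_messages("http://localhost:7331", "self-check", [{"role": "user", "content": "hi"}])
```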

@@ -20,7 +20,8 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import DEFAULT_OLLAMA_URL, OllamaImplConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.ollama.config import DEFAULT_OLLAMA_URL
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig

@@ -95,6 +95,7 @@ def safety_bedrock() -> ProviderFixture:
],
)
@pytest.fixture(scope="session")
def safety_nvidia() -> ProviderFixture:
return ProviderFixture(

@@ -20,7 +20,7 @@ class TestSafety:
@pytest.mark.asyncio
async def test_shield_list(self, safety_stack):
_, shields_impl, _ = safety_stack
response = await shields_impl.list_shields()
response = (await shields_impl.list_shields()).data
assert isinstance(response, list)
assert len(response) >= 1
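The test now unwraps `.data` because `list_shields()` returns a response object wrapping the list rather than a bare list; roughly, with an illustrative stand-in type:

```python
from dataclasses import dataclass, field
from typing import Any, List

# Illustrative stand-in, not the real Llama Stack response type: the
# shields API now wraps its results in an object with a .data field.
@dataclass
class ListShieldsResponse:
    data: List[Any] = field(default_factory=list)

response = ListShieldsResponse(data=["shield-a"])
shields = response.data   # what the updated test unwraps
assert isinstance(shields, list) and len(shields) >= 1
```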

@@ -7,7 +7,7 @@ distribution_spec:
vector_io:
- inline::faiss
safety:
- remote::nvidia
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:

@@ -17,6 +17,11 @@ providers:
config:
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:}
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
vector_io:
- provider_id: faiss
provider_type: inline::faiss
@@ -26,11 +31,9 @@ providers:
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
safety:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
@@ -93,57 +96,19 @@ metadata_store:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
models:
- metadata: {}
model_id: meta-llama/Llama-3-8B-Instruct
model_id: ${env.INFERENCE_MODEL}
provider_id: nvidia
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3-70B-Instruct
model_id: ${env.SAFETY_MODEL}
provider_id: nvidia
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: nvidia
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: []
datasets: []
scoring_fns: []
eval_tasks: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
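
As a usage note, the model and shield entries above resolve `${env.INFERENCE_MODEL}` and `${env.SAFETY_MODEL}` at startup, so those variables must be set (or left to the documented defaults) before launching; a Python sketch, equivalent to shell exports:

```python
import os

# Illustrative: supply the variables run.yaml substitutes; the values shown
# are the defaults documented at the top of this commit.
os.environ.setdefault("INFERENCE_MODEL", "Llama3.1-8B-Instruct")
os.environ.setdefault("SAFETY_MODEL", "meta/llama-3.1-8b-instruct")
os.environ.setdefault("GUARDRAILS_SERVICE_URL", "http://localhost:7331")
```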