From 66726241aae6472e12bf4b67b1ff102fab76e18c Mon Sep 17 00:00:00 2001 From: Chantal D Gama Rose Date: Fri, 21 Feb 2025 07:19:40 +0000 Subject: [PATCH] fixed breaking tests and run pre-commit --- .../remote_hosted_distro/nvidia.md | 4 +- .../providers/remote/safety/nvidia/config.py | 15 ++--- .../providers/remote/safety/nvidia/nvidia.py | 31 ++++------ .../providers/tests/inference/fixtures.py | 3 +- .../providers/tests/safety/fixtures.py | 1 + .../providers/tests/safety/test_safety.py | 2 +- llama_stack/templates/nvidia/build.yaml | 2 +- .../templates/nvidia/run-with-safety.yaml | 57 ++++--------------- 8 files changed, 34 insertions(+), 81 deletions(-) diff --git a/docs/source/distributions/remote_hosted_distro/nvidia.md b/docs/source/distributions/remote_hosted_distro/nvidia.md index f352f737e..e2a0e1253 100644 --- a/docs/source/distributions/remote_hosted_distro/nvidia.md +++ b/docs/source/distributions/remote_hosted_distro/nvidia.md @@ -20,8 +20,10 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov The following environment variables can be configured: -- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``) +- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`) +- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`) +- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`) ### Models diff --git a/llama_stack/providers/remote/safety/nvidia/config.py b/llama_stack/providers/remote/safety/nvidia/config.py index c3d7b609d..e688db1b9 100644 --- a/llama_stack/providers/remote/safety/nvidia/config.py +++ b/llama_stack/providers/remote/safety/nvidia/config.py @@ -35,18 +35,13 @@ class NVIDIASafetyConfig(BaseModel): By default the configuration will attempt to read the NVIDIA_API_KEY environment variable to set the api_key. Please do not put your API key in code. """ + guardrails_service_url: str = Field( default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"), description="The url for accessing the guardrails service", ) - config_id: Optional[str] = Field( - default="self-check", - description="Config ID to use from the config store" - ) - config_store_path: Optional[str] = Field( - default="/config-store", - description="Path to config store" - ) + config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store") + config_store_path: Optional[str] = Field(default="/config-store", description="Path to config store") @classmethod @field_validator("guard_type") @@ -54,10 +49,10 @@ class NVIDIASafetyConfig(BaseModel): if v not in [t.value for t in ShieldType]: raise ValueError(f"Unknown shield type: {v}") return v - + @classmethod def sample_run_config(cls, **kwargs) -> Dict[str, Any]: return { "guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}", - "config_id": "self-check" + "config_id": "self-check", } diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 39379c2eb..11bdd14a2 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -25,17 +25,6 @@ from .config import NVIDIASafetyConfig logger = logging.getLogger(__name__) -SHIELD_IDS_TO_MODEL_MAPPING = { - CoreModelId.llama3_8b_instruct.value: "meta/llama3-8b-instruct", - CoreModelId.llama3_70b_instruct.value: "meta/llama3-70b-instruct", - CoreModelId.llama3_1_8b_instruct.value: "meta/llama-3.1-8b-instruct", - CoreModelId.llama3_1_70b_instruct.value: "meta/llama-3.1-70b-instruct", - CoreModelId.llama3_1_405b_instruct.value: "meta/llama-3.1-405b-instruct", - CoreModelId.llama3_2_1b_instruct.value: "meta/llama-3.2-1b-instruct", - CoreModelId.llama3_2_3b_instruct.value: "meta/llama-3.2-3b-instruct", - CoreModelId.llama3_2_11b_vision_instruct.value: "meta/llama-3.2-11b-vision-instruct", - CoreModelId.llama3_2_90b_vision_instruct.value: "meta/llama-3.2-90b-vision-instruct" -} class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: NVIDIASafetyConfig) -> None: @@ -59,11 +48,11 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): shield = await self.shield_store.get_shield(shield_id) if not shield: raise ValueError(f"Shield {shield_id} not found") - # print(shield.provider_shield_id) + self.shield = NeMoGuardrails(self.config, shield.shield_id) return await self.shield.run(messages) - - + + class NeMoGuardrails: def __init__( self, @@ -75,7 +64,9 @@ class NeMoGuardrails: self.config_id = config.config_id self.config_store_path = config.config_store_path self.model = model - assert self.config_id is not None or self.config_store_path is not None, "Must provide one of config id or config store path" + assert self.config_id is not None or self.config_store_path is not None, ( + "Must provide one of config id or config store path" + ) if temperature <= 0: raise ValueError("Temperature must be greater than 0") @@ -99,21 +90,19 @@ class NeMoGuardrails: "stream": False, "guardrails": { "config_id": self.config_id, - } + }, } response = requests.post( - url=f"{self.guardrails_service_url}/v1/guardrail/checks", - headers=headers, - json=request_data + url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data ) print(response) response.raise_for_status() - if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'): + if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"): response_json = response.json() if response_json["status"] == "blocked": user_message = "Sorry I cannot do this." metadata = response_json["rails_status"] - + return RunShieldResponse( violation=SafetyViolation( user_message=user_message, diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index b553b6b02..5291bffb3 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -20,7 +20,8 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.groq import GroqConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig -from llama_stack.providers.remote.inference.ollama import DEFAULT_OLLAMA_URL, OllamaImplConfig +from llama_stack.providers.remote.inference.ollama import OllamaImplConfig +from llama_stack.providers.remote.inference.ollama.config import DEFAULT_OLLAMA_URL from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig from llama_stack.providers.remote.inference.tgi import TGIImplConfig from llama_stack.providers.remote.inference.together import TogetherImplConfig diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 3f85473c5..6271cd58a 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -95,6 +95,7 @@ def safety_bedrock() -> ProviderFixture: ], ) + @pytest.fixture(scope="session") def safety_nvidia() -> ProviderFixture: return ProviderFixture( diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index 101f2224f..dab7fc186 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -20,7 +20,7 @@ class TestSafety: @pytest.mark.asyncio async def test_shield_list(self, safety_stack): _, shields_impl, _ = safety_stack - response = await shields_impl.list_shields() + response = (await shields_impl.list_shields()).data assert isinstance(response, list) assert len(response) >= 1 diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index 63a227d4f..e9748721a 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -7,7 +7,7 @@ distribution_spec: vector_io: - inline::faiss safety: - - remote::nvidia + - inline::llama-guard agents: - inline::meta-reference telemetry: diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index bfb346c7d..ec52e3ef1 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -17,6 +17,11 @@ providers: config: url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com} api_key: ${env.NVIDIA_API_KEY:} + - provider_id: nvidia + provider_type: remote::nvidia + config: + guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331} + config_id: self-check vector_io: - provider_id: faiss provider_type: inline::faiss @@ -26,11 +31,9 @@ providers: namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db safety: - - provider_id: nvidia - provider_type: remote::nvidia - config: - url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331} - config_id: self-check + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -93,57 +96,19 @@ metadata_store: db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db models: - metadata: {} - model_id: meta-llama/Llama-3-8B-Instruct + model_id: ${env.INFERENCE_MODEL} provider_id: nvidia - provider_model_id: meta/llama3-8b-instruct model_type: llm - metadata: {} - model_id: meta-llama/Llama-3-70B-Instruct + model_id: ${env.SAFETY_MODEL} provider_id: nvidia - provider_model_id: meta/llama3-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-8B-Instruct - provider_id: nvidia - provider_model_id: meta/llama-3.1-8b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-70B-Instruct - provider_id: nvidia - provider_model_id: meta/llama-3.1-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 - provider_id: nvidia - provider_model_id: meta/llama-3.1-405b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-1B-Instruct - provider_id: nvidia - provider_model_id: meta/llama-3.2-1b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-3B-Instruct - provider_id: nvidia - provider_model_id: meta/llama-3.2-3b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-11B-Vision-Instruct - provider_id: nvidia - provider_model_id: meta/llama-3.2-11b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-90B-Vision-Instruct - provider_id: nvidia - provider_model_id: meta/llama-3.2-90b-vision-instruct model_type: llm shields: - shield_id: ${env.SAFETY_MODEL} - provider_id: nvidia vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search