commit 66726241aa
parent 78b1105f5d
Author: Chantal D Gama Rose
Date:   2025-02-21 07:19:40 +00:00

    fixed breaking tests and run pre-commit

8 changed files with 34 additions and 81 deletions

@@ -20,8 +20,10 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 The following environment variables can be configured:
 - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
+- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
+- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
+- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)

 ### Models
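A note on the `${env.VAR:default}` syntax these docs and the run.yaml below rely on: the text after the colon is the fallback used when the variable is unset. A minimal sketch of that substitution, assuming only the syntax shown in the templates (the helper `substitute_env_vars` is illustrative, not the repo's implementation):

import os
import re

# Resolve ${env.NAME:default} placeholders; hypothetical helper, mirroring the
# placeholder syntax visible in the templates, not llama_stack's actual code.
_ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}")

def substitute_env_vars(value: str) -> str:
    def _replace(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return os.environ.get(name, default if default is not None else "")
    return _ENV_PATTERN.sub(_replace, value)

print(substitute_env_vars("${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}"))
# -> http://localhost:7331 unless GUARDRAILS_SERVICE_URL is set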

@@ -35,18 +35,13 @@ class NVIDIASafetyConfig(BaseModel):
     By default the configuration will attempt to read the NVIDIA_API_KEY environment
     variable to set the api_key. Please do not put your API key in code.
     """
     guardrails_service_url: str = Field(
         default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
         description="The url for accessing the guardrails service",
     )
-    config_id: Optional[str] = Field(
-        default="self-check",
-        description="Config ID to use from the config store"
-    )
-    config_store_path: Optional[str] = Field(
-        default="/config-store",
-        description="Path to config store"
-    )
+    config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
+    config_store_path: Optional[str] = Field(default="/config-store", description="Path to config store")

     @classmethod
     @field_validator("guard_type")
@@ -59,5 +54,5 @@ class NVIDIASafetyConfig(BaseModel):
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
         return {
             "guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
-            "config_id": "self-check"
+            "config_id": "self-check",
         }
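A short sketch exercising the reflowed fields above, assuming pydantic v2 and that the class is importable from the remote::nvidia safety provider package (the import path is inferred from the provider registration, not stated in this diff):

import os

# Import path is an assumption based on the remote::nvidia provider layout.
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig

os.environ.pop("NVIDIA_BASE_URL", None)  # fall back to the default URL
config = NVIDIASafetyConfig()
assert config.guardrails_service_url == "http://0.0.0.0:7331"
assert config.config_id == "self-check"
assert config.config_store_path == "/config-store"

# sample_run_config() emits the ${env.*} placeholders consumed by run.yaml.
print(NVIDIASafetyConfig.sample_run_config())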

@@ -25,17 +25,6 @@ from .config import NVIDIASafetyConfig

 logger = logging.getLogger(__name__)

-SHIELD_IDS_TO_MODEL_MAPPING = {
-    CoreModelId.llama3_8b_instruct.value: "meta/llama3-8b-instruct",
-    CoreModelId.llama3_70b_instruct.value: "meta/llama3-70b-instruct",
-    CoreModelId.llama3_1_8b_instruct.value: "meta/llama-3.1-8b-instruct",
-    CoreModelId.llama3_1_70b_instruct.value: "meta/llama-3.1-70b-instruct",
-    CoreModelId.llama3_1_405b_instruct.value: "meta/llama-3.1-405b-instruct",
-    CoreModelId.llama3_2_1b_instruct.value: "meta/llama-3.2-1b-instruct",
-    CoreModelId.llama3_2_3b_instruct.value: "meta/llama-3.2-3b-instruct",
-    CoreModelId.llama3_2_11b_vision_instruct.value: "meta/llama-3.2-11b-vision-instruct",
-    CoreModelId.llama3_2_90b_vision_instruct.value: "meta/llama-3.2-90b-vision-instruct"
-}

 class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
     def __init__(self, config: NVIDIASafetyConfig) -> None:
@@ -59,7 +48,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
             raise ValueError(f"Shield {shield_id} not found")
-        # print(shield.provider_shield_id)
+
         self.shield = NeMoGuardrails(self.config, shield.shield_id)
         return await self.shield.run(messages)
@@ -75,7 +64,9 @@ class NeMoGuardrails:
         self.config_id = config.config_id
         self.config_store_path = config.config_store_path
         self.model = model
-        assert self.config_id is not None or self.config_store_path is not None, "Must provide one of config id or config store path"
+        assert self.config_id is not None or self.config_store_path is not None, (
+            "Must provide one of config id or config store path"
+        )

         if temperature <= 0:
             raise ValueError("Temperature must be greater than 0")
@@ -99,16 +90,14 @@ class NeMoGuardrails:
             "stream": False,
             "guardrails": {
                 "config_id": self.config_id,
-            }
+            },
         }
         response = requests.post(
-            url=f"{self.guardrails_service_url}/v1/guardrail/checks",
-            headers=headers,
-            json=request_data
+            url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
         )
         print(response)
         response.raise_for_status()
-        if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'):
+        if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
             response_json = response.json()
             if response_json["status"] == "blocked":
                 user_message = "Sorry I cannot do this."
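The consolidated requests.post call above targets the NeMo Guardrails checks endpoint. A standalone sketch of the same request, useful for poking a local guardrails service by hand; the "stream" and "guardrails" keys mirror the diff, while the model name, message content, and headers are placeholder assumptions:

import requests

# Reproduce the request NeMoGuardrails.run() issues against a local service.
request_data = {
    "model": "meta/llama-3.1-8b-instruct",  # placeholder; run() uses self.model
    "messages": [{"role": "user", "content": "Tell me something unsafe."}],
    "stream": False,
    "guardrails": {
        "config_id": "self-check",
    },
}
response = requests.post(
    url="http://localhost:7331/v1/guardrail/checks",
    headers={"Accept": "application/json"},  # assumed headers
    json=request_data,
)
response.raise_for_status()
if response.headers.get("Content-Type", "").startswith("application/json"):
    body = response.json()
    # A "blocked" status means the guardrail refused the input.
    print(body.get("status"))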

@@ -20,7 +20,8 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.groq import GroqConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
-from llama_stack.providers.remote.inference.ollama import DEFAULT_OLLAMA_URL, OllamaImplConfig
+from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.ollama.config import DEFAULT_OLLAMA_URL
 from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
 from llama_stack.providers.remote.inference.tgi import TGIImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig

@@ -95,6 +95,7 @@ def safety_bedrock() -> ProviderFixture:
         ],
     )

+
 @pytest.fixture(scope="session")
 def safety_nvidia() -> ProviderFixture:
     return ProviderFixture(
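The diff truncates before the fixture body. For orientation only, a hedged sketch of how a ProviderFixture-based safety fixture is typically shaped in this file; every field below is hypothetical, not the commit's actual code:

# Hypothetical body for safety_nvidia, modeled on the other safety fixtures;
# Provider/ProviderFixture names come from the surrounding test harness.
from llama_stack.distribution.datatypes import Provider  # assumed import path

@pytest.fixture(scope="session")
def safety_nvidia() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="nvidia",
                provider_type="remote::nvidia",
                config=NVIDIASafetyConfig().model_dump(),  # assumed config shape
            )
        ],
    )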

@@ -20,7 +20,7 @@ class TestSafety:
     @pytest.mark.asyncio
     async def test_shield_list(self, safety_stack):
         _, shields_impl, _ = safety_stack
-        response = await shields_impl.list_shields()
+        response = (await shields_impl.list_shields()).data
         assert isinstance(response, list)
         assert len(response) >= 1
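The updated assertion implies list_shields() now returns a wrapper object whose .data attribute carries the list, rather than a bare list. A minimal sketch of that response shape, with illustrative names only:

import asyncio
from dataclasses import dataclass, field
from typing import List

@dataclass
class Shield:
    identifier: str

@dataclass
class ListShieldsResponse:
    # The wrapper the updated test unpacks via `.data`; names are illustrative.
    data: List[Shield] = field(default_factory=list)

async def list_shields() -> ListShieldsResponse:
    return ListShieldsResponse(data=[Shield(identifier="meta/llama-3.1-8b-instruct")])

async def main() -> None:
    response = (await list_shields()).data  # same unwrapping as the test
    assert isinstance(response, list)
    assert len(response) >= 1

asyncio.run(main())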

@@ -7,7 +7,7 @@ distribution_spec:
     vector_io:
     - inline::faiss
     safety:
-    - remote::nvidia
+    - inline::llama-guard
     agents:
     - inline::meta-reference
     telemetry:

@@ -17,6 +17,11 @@ providers:
     config:
       url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
       api_key: ${env.NVIDIA_API_KEY:}
+  - provider_id: nvidia
+    provider_type: remote::nvidia
+    config:
+      guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
+      config_id: self-check
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -26,11 +31,9 @@ providers:
         namespace: null
         db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
   safety:
-  - provider_id: nvidia
-    provider_type: remote::nvidia
-    config:
-      url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
-      config_id: self-check
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config: {}
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -93,57 +96,19 @@ metadata_store:
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
 models:
 - metadata: {}
-  model_id: meta-llama/Llama-3-8B-Instruct
+  model_id: ${env.INFERENCE_MODEL}
   provider_id: nvidia
-  provider_model_id: meta/llama3-8b-instruct
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-3-70B-Instruct
+  model_id: ${env.SAFETY_MODEL}
   provider_id: nvidia
-  provider_model_id: meta/llama3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.1-405b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-90b-vision-instruct
   model_type: llm
 shields:
 - shield_id: ${env.SAFETY_MODEL}
-  provider_id: nvidia
 vector_dbs: []
 datasets: []
 scoring_fns: []
-eval_tasks: []
+benchmarks: []
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
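With the shield registered under shields:, a client-side sketch of exercising it end to end; this assumes the llama-stack-client Python package, whose safety.run_shield method matches this call shape in recent versions (verify against your installed client):

import os

from llama_stack_client import LlamaStackClient

# Point at a running distribution; port default matches the docs above.
client = LlamaStackClient(base_url=f"http://localhost:{os.getenv('LLAMASTACK_PORT', '5001')}")

result = client.safety.run_shield(
    shield_id=os.environ["SAFETY_MODEL"],  # matches the shields: entry above
    messages=[{"role": "user", "content": "How do I hotwire a car?"}],
    params={},
)
print(result)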