Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 13:00:39 +00:00)

Commit 66726241aa (parent 78b1105f5d): fixed breaking tests and run pre-commit

8 changed files with 34 additions and 81 deletions
@@ -20,8 +20,10 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 The following environment variables can be configured:

 - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
+- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
 - `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
+- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
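As a quick illustration, a hedged sketch of reading these variables with their documented defaults (plain os.getenv is an assumption about how the server resolves them):

import os

port = int(os.getenv("LLAMASTACK_PORT", "5001"))
nvidia_api_key = os.getenv("NVIDIA_API_KEY", "")
guardrails_url = os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331")
inference_model = os.getenv("INFERENCE_MODEL", "Llama3.1-8B-Instruct")
safety_model = os.getenv("SAFETY_MODEL", "meta/llama-3.1-8b-instruct")
print(port, guardrails_url, inference_model, safety_model)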

 ### Models
@@ -35,18 +35,13 @@ class NVIDIASafetyConfig(BaseModel):
     By default the configuration will attempt to read the NVIDIA_API_KEY environment
     variable to set the api_key. Please do not put your API key in code.
     """

     guardrails_service_url: str = Field(
         default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
         description="The url for accessing the guardrails service",
     )
-    config_id: Optional[str] = Field(
-        default="self-check",
-        description="Config ID to use from the config store"
-    )
-    config_store_path: Optional[str] = Field(
-        default="/config-store",
-        description="Path to config store"
-    )
+    config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
+    config_store_path: Optional[str] = Field(default="/config-store", description="Path to config store")

     @classmethod
     @field_validator("guard_type")
@@ -54,10 +49,10 @@ class NVIDIASafetyConfig(BaseModel):
         if v not in [t.value for t in ShieldType]:
             raise ValueError(f"Unknown shield type: {v}")
         return v

     @classmethod
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
         return {
             "guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
-            "config_id": "self-check"
+            "config_id": "self-check",
         }
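The `${env.VAR:default}` strings produced by `sample_run_config` are placeholders that the stack substitutes from the environment when a run config is loaded. A minimal sketch of the assumed resolution semantics (the helper and its regex are illustrative, not the stack's actual implementation):

import os
import re


def resolve_env_placeholder(value: str) -> str:
    # "${env.NAME:default}" -> NAME's value if set, else the default;
    # non-placeholder strings pass through unchanged.
    match = re.fullmatch(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::(.*))?\}", value)
    if match is None:
        return value
    name, default = match.group(1), match.group(2) or ""
    return os.getenv(name, default)


print(resolve_env_placeholder("${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}"))
# -> "http://localhost:7331" unless GUARDRAILS_SERVICE_URL is set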
@@ -25,17 +25,6 @@ from .config import NVIDIASafetyConfig

 logger = logging.getLogger(__name__)

-SHIELD_IDS_TO_MODEL_MAPPING = {
-    CoreModelId.llama3_8b_instruct.value: "meta/llama3-8b-instruct",
-    CoreModelId.llama3_70b_instruct.value: "meta/llama3-70b-instruct",
-    CoreModelId.llama3_1_8b_instruct.value: "meta/llama-3.1-8b-instruct",
-    CoreModelId.llama3_1_70b_instruct.value: "meta/llama-3.1-70b-instruct",
-    CoreModelId.llama3_1_405b_instruct.value: "meta/llama-3.1-405b-instruct",
-    CoreModelId.llama3_2_1b_instruct.value: "meta/llama-3.2-1b-instruct",
-    CoreModelId.llama3_2_3b_instruct.value: "meta/llama-3.2-3b-instruct",
-    CoreModelId.llama3_2_11b_vision_instruct.value: "meta/llama-3.2-11b-vision-instruct",
-    CoreModelId.llama3_2_90b_vision_instruct.value: "meta/llama-3.2-90b-vision-instruct"
-}

 class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
     def __init__(self, config: NVIDIASafetyConfig) -> None:
@@ -59,11 +48,11 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
             raise ValueError(f"Shield {shield_id} not found")
         # print(shield.provider_shield_id)

         self.shield = NeMoGuardrails(self.config, shield.shield_id)
         return await self.shield.run(messages)


 class NeMoGuardrails:
     def __init__(
         self,
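With `SHIELD_IDS_TO_MODEL_MAPPING` removed, `run_shield` hands `shield.shield_id` to `NeMoGuardrails` directly; the dict appears to have previously translated canonical model IDs into NVIDIA model names. A self-contained sketch of the new flow, using stand-in types so it runs on its own (the real shield store and signatures live in llama-stack):

import asyncio
from types import SimpleNamespace


class FakeShieldStore:
    # Stand-in for the adapter's shield_store.
    async def get_shield(self, shield_id: str) -> SimpleNamespace:
        return SimpleNamespace(shield_id=shield_id)


async def run_shield(shield_store: FakeShieldStore, shield_id: str) -> str:
    shield = await shield_store.get_shield(shield_id)
    if not shield:
        raise ValueError(f"Shield {shield_id} not found")
    # The shield ID is used as-is; no mapping table lookup anymore.
    return shield.shield_id


print(asyncio.run(run_shield(FakeShieldStore(), "meta/llama-3.1-8b-instruct")))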
@@ -75,7 +64,9 @@ class NeMoGuardrails:
         self.config_id = config.config_id
         self.config_store_path = config.config_store_path
         self.model = model
-        assert self.config_id is not None or self.config_store_path is not None, "Must provide one of config id or config store path"
+        assert self.config_id is not None or self.config_store_path is not None, (
+            "Must provide one of config id or config store path"
+        )
         if temperature <= 0:
             raise ValueError("Temperature must be greater than 0")
@@ -99,21 +90,19 @@ class NeMoGuardrails:
             "stream": False,
             "guardrails": {
                 "config_id": self.config_id,
-            }
+            },
         }
         response = requests.post(
-            url=f"{self.guardrails_service_url}/v1/guardrail/checks",
-            headers=headers,
-            json=request_data
+            url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
         )
         print(response)
         response.raise_for_status()
-        if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'):
+        if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
             response_json = response.json()
             if response_json["status"] == "blocked":
                 user_message = "Sorry I cannot do this."
                 metadata = response_json["rails_status"]

         return RunShieldResponse(
             violation=SafetyViolation(
                 user_message=user_message,
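For reference, a self-contained sketch of the HTTP exchange this method performs against the Guardrails service, using the default URL and the payload shape visible above (the headers dict and the "model" field are assumptions, and a live service must be listening for this to run):

import requests

guardrails_service_url = "http://localhost:7331"  # default from sample_run_config above
request_data = {
    "model": "meta/llama-3.1-8b-instruct",  # assumed field, mirroring the shield's model
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
    "guardrails": {"config_id": "self-check"},
}
response = requests.post(
    url=f"{guardrails_service_url}/v1/guardrail/checks",
    headers={"Accept": "application/json"},  # headers content is not shown in the diff
    json=request_data,
)
response.raise_for_status()
if response.headers.get("Content-Type", "").startswith("application/json"):
    response_json = response.json()
    if response_json["status"] == "blocked":
        # rails_status is what the adapter surfaces as violation metadata.
        print("blocked:", response_json["rails_status"])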
@@ -20,7 +20,8 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.groq import GroqConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
-from llama_stack.providers.remote.inference.ollama import DEFAULT_OLLAMA_URL, OllamaImplConfig
+from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.ollama.config import DEFAULT_OLLAMA_URL
 from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
 from llama_stack.providers.remote.inference.tgi import TGIImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
@@ -95,6 +95,7 @@ def safety_bedrock() -> ProviderFixture:
         ],
     )


 @pytest.fixture(scope="session")
 def safety_nvidia() -> ProviderFixture:
     return ProviderFixture(
@@ -20,7 +20,7 @@ class TestSafety:
     @pytest.mark.asyncio
     async def test_shield_list(self, safety_stack):
         _, shields_impl, _ = safety_stack
-        response = await shields_impl.list_shields()
+        response = (await shields_impl.list_shields()).data
         assert isinstance(response, list)
         assert len(response) >= 1
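The fix unwraps `.data`, which indicates `list_shields()` now returns a wrapper object rather than a bare list. A self-contained sketch of that shape (the type name is a stand-in, not the real llama-stack response class):

import asyncio
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ListShieldsResponseSketch:
    data: List[Any]


async def list_shields() -> ListShieldsResponseSketch:
    return ListShieldsResponseSketch(data=["llama-guard"])


shields = asyncio.run(list_shields()).data
assert isinstance(shields, list) and len(shields) >= 1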
@@ -7,7 +7,7 @@ distribution_spec:
   vector_io:
   - inline::faiss
   safety:
   - remote::nvidia
   - inline::llama-guard
   agents:
   - inline::meta-reference
   telemetry:
@@ -17,6 +17,11 @@ providers:
     config:
       url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
       api_key: ${env.NVIDIA_API_KEY:}
+  - provider_id: nvidia
+    provider_type: remote::nvidia
+    config:
+      guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
+      config_id: self-check
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -26,11 +31,9 @@ providers:
         namespace: null
         db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
   safety:
-  - provider_id: nvidia
-    provider_type: remote::nvidia
-    config:
-      url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
-      config_id: self-check
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config: {}
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -93,57 +96,19 @@ metadata_store:
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
 models:
 - metadata: {}
-  model_id: meta-llama/Llama-3-8B-Instruct
+  model_id: ${env.INFERENCE_MODEL}
   provider_id: nvidia
   provider_model_id: meta/llama3-8b-instruct
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-3-70B-Instruct
+  model_id: ${env.SAFETY_MODEL}
   provider_id: nvidia
   provider_model_id: meta/llama3-70b-instruct
   model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.1-405b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-90b-vision-instruct
-  model_type: llm
 shields:
 - shield_id: ${env.SAFETY_MODEL}
   provider_id: nvidia
 vector_dbs: []
 datasets: []
 scoring_fns: []
-eval_tasks: []
+benchmarks: []
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
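Because the model and shield entries are now driven by environment variables, the effective default registry can be sketched in a few lines (defaults taken from the environment-variable list in the docs hunk above; the dict shapes are illustrative, not the stack's internal types):

import os

inference_model = os.getenv("INFERENCE_MODEL", "Llama3.1-8B-Instruct")
safety_model = os.getenv("SAFETY_MODEL", "meta/llama-3.1-8b-instruct")

models = [
    {"model_id": inference_model, "provider_id": "nvidia", "model_type": "llm"},
    {"model_id": safety_model, "provider_id": "nvidia", "model_type": "llm"},
]
# The shield is keyed by the same env-selected safety model:
shields = [{"shield_id": safety_model, "provider_id": "nvidia"}]
print(models, shields)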