mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-13 21:29:57 +00:00)

fixed breaking tests and run pre-commit

commit 66726241aa (parent 78b1105f5d)
8 changed files with 34 additions and 81 deletions
@@ -20,8 +20,10 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
+- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
+- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
+- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
 
 ### Models
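These variables are consumed through `${env.VAR:default}` placeholders in the template's run config, visible in the YAML hunks later in this diff. A minimal sketch of that substitution rule, assuming the `${env.NAME:default}` syntax shown here; the resolver below is illustrative, not the stack's actual implementation:

import os
import re

# Illustrative resolver for ${env.VAR:default} placeholders; llama-stack's
# real resolver lives elsewhere and may differ in edge cases.
_ENV_REF = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):?([^}]*)\}")

def resolve_env_refs(template: str) -> str:
    # Substitute each placeholder with the environment value, falling back
    # to the default that follows the colon (empty string when absent).
    return _ENV_REF.sub(lambda m: os.getenv(m.group(1), m.group(2)), template)

print(resolve_env_refs("${env.GUARDRAILS_SERVICE_URL:http://0.0.0.0:7331}"))
# -> "http://0.0.0.0:7331" unless GUARDRAILS_SERVICE_URL is set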
@@ -35,18 +35,13 @@ class NVIDIASafetyConfig(BaseModel):
     By default the configuration will attempt to read the NVIDIA_API_KEY environment
     variable to set the api_key. Please do not put your API key in code.
     """
 
     guardrails_service_url: str = Field(
         default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
         description="The url for accessing the guardrails service",
     )
-    config_id: Optional[str] = Field(
-        default="self-check",
-        description="Config ID to use from the config store"
-    )
-    config_store_path: Optional[str] = Field(
-        default="/config-store",
-        description="Path to config store"
-    )
+    config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
+    config_store_path: Optional[str] = Field(default="/config-store", description="Path to config store")
 
     @classmethod
     @field_validator("guard_type")
@@ -59,5 +54,5 @@ class NVIDIASafetyConfig(BaseModel):
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
         return {
             "guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
-            "config_id": "self-check"
+            "config_id": "self-check",
         }
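The two collapsed Field definitions are a pure pre-commit reflow; behavior is unchanged. A quick sketch of the resulting defaults, assuming pydantic and the module's own imports (os, Optional, Field):

import os
from typing import Optional

from pydantic import BaseModel, Field

class NVIDIASafetyConfig(BaseModel):
    # Mirrors the fields in the hunk above; descriptions omitted for brevity.
    guardrails_service_url: str = Field(
        default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331")
    )
    config_id: Optional[str] = Field(default="self-check")
    config_store_path: Optional[str] = Field(default="/config-store")

cfg = NVIDIASafetyConfig()
assert cfg.config_id == "self-check"
assert cfg.config_store_path == "/config-store"
# guardrails_service_url falls back to http://0.0.0.0:7331 when
# NVIDIA_BASE_URL is unset.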
@@ -25,17 +25,6 @@ from .config import NVIDIASafetyConfig
 
 logger = logging.getLogger(__name__)
 
-SHIELD_IDS_TO_MODEL_MAPPING = {
-    CoreModelId.llama3_8b_instruct.value: "meta/llama3-8b-instruct",
-    CoreModelId.llama3_70b_instruct.value: "meta/llama3-70b-instruct",
-    CoreModelId.llama3_1_8b_instruct.value: "meta/llama-3.1-8b-instruct",
-    CoreModelId.llama3_1_70b_instruct.value: "meta/llama-3.1-70b-instruct",
-    CoreModelId.llama3_1_405b_instruct.value: "meta/llama-3.1-405b-instruct",
-    CoreModelId.llama3_2_1b_instruct.value: "meta/llama-3.2-1b-instruct",
-    CoreModelId.llama3_2_3b_instruct.value: "meta/llama-3.2-3b-instruct",
-    CoreModelId.llama3_2_11b_vision_instruct.value: "meta/llama-3.2-11b-vision-instruct",
-    CoreModelId.llama3_2_90b_vision_instruct.value: "meta/llama-3.2-90b-vision-instruct"
-}
 
 class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
     def __init__(self, config: NVIDIASafetyConfig) -> None:
@@ -59,7 +48,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
             raise ValueError(f"Shield {shield_id} not found")
-        # print(shield.provider_shield_id)
+
         self.shield = NeMoGuardrails(self.config, shield.shield_id)
         return await self.shield.run(messages)
 
@@ -75,7 +64,9 @@ class NeMoGuardrails:
         self.config_id = config.config_id
         self.config_store_path = config.config_store_path
         self.model = model
-        assert self.config_id is not None or self.config_store_path is not None, "Must provide one of config id or config store path"
+        assert self.config_id is not None or self.config_store_path is not None, (
+            "Must provide one of config id or config store path"
+        )
         if temperature <= 0:
             raise ValueError("Temperature must be greater than 0")
 
@@ -99,16 +90,14 @@ class NeMoGuardrails:
             "stream": False,
             "guardrails": {
                 "config_id": self.config_id,
-            }
+            },
         }
         response = requests.post(
-            url=f"{self.guardrails_service_url}/v1/guardrail/checks",
-            headers=headers,
-            json=request_data
+            url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
         )
         print(response)
         response.raise_for_status()
-        if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'):
+        if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
             response_json = response.json()
             if response_json["status"] == "blocked":
                 user_message = "Sorry I cannot do this."
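For reference, the call above reduces to a plain POST against the guardrails service. A standalone sketch of the request shape, assuming a locally running service; the model, messages, and temperature fields are inferred from the surrounding NeMoGuardrails code rather than shown in this hunk:

import requests

guardrails_service_url = "http://localhost:7331"  # ${env.GUARDRAILS_SERVICE_URL:...}
request_data = {
    # "model", "messages", and "temperature" are assumptions based on the
    # surrounding code, not part of the hunk itself.
    "model": "meta/llama-3.1-8b-instruct",
    "messages": [{"role": "user", "content": "hello"}],
    "temperature": 1.0,
    "stream": False,
    "guardrails": {
        "config_id": "self-check",
    },
}
response = requests.post(
    url=f"{guardrails_service_url}/v1/guardrail/checks",
    headers={"Accept": "application/json"},
    json=request_data,
)
response.raise_for_status()
if response.headers.get("Content-Type", "").startswith("application/json"):
    print(response.json().get("status"))  # e.g. "blocked" when the guardrail trips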
@@ -20,7 +20,8 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.groq import GroqConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
-from llama_stack.providers.remote.inference.ollama import DEFAULT_OLLAMA_URL, OllamaImplConfig
+from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.ollama.config import DEFAULT_OLLAMA_URL
 from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
 from llama_stack.providers.remote.inference.tgi import TGIImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
@@ -95,6 +95,7 @@ def safety_bedrock() -> ProviderFixture:
         ],
     )
 
+
 @pytest.fixture(scope="session")
 def safety_nvidia() -> ProviderFixture:
     return ProviderFixture(
@@ -20,7 +20,7 @@ class TestSafety:
     @pytest.mark.asyncio
     async def test_shield_list(self, safety_stack):
         _, shields_impl, _ = safety_stack
-        response = await shields_impl.list_shields()
+        response = (await shields_impl.list_shields()).data
         assert isinstance(response, list)
         assert len(response) >= 1
 
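The `.data` change reflects list_shields now returning a response wrapper whose `data` attribute carries the list, rather than a bare list. A sketch of the shape the updated assertions expect; the wrapper and fixture classes here are hypothetical stand-ins, not the actual llama-stack types:

import asyncio
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class _ListShieldsResponse:  # hypothetical stand-in for the real response type
    data: List[Any] = field(default_factory=list)

class _FakeShieldsImpl:
    async def list_shields(self) -> _ListShieldsResponse:
        return _ListShieldsResponse(data=["shield-a"])

async def main() -> None:
    shields_impl = _FakeShieldsImpl()
    response = (await shields_impl.list_shields()).data  # matches the updated test
    assert isinstance(response, list)
    assert len(response) >= 1

asyncio.run(main())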
@@ -7,7 +7,7 @@ distribution_spec:
     vector_io:
     - inline::faiss
     safety:
-    - remote::nvidia
+    - inline::llama-guard
     agents:
     - inline::meta-reference
     telemetry:
@@ -17,6 +17,11 @@ providers:
     config:
       url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
       api_key: ${env.NVIDIA_API_KEY:}
+  - provider_id: nvidia
+    provider_type: remote::nvidia
+    config:
+      guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
+      config_id: self-check
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
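The new provider block maps one-to-one onto NVIDIASafetyConfig from earlier in this commit; constructing it directly would look like this sketch (the import path is assumed from the repository layout, since the diff only shows "from .config import NVIDIASafetyConfig"):

# Path assumed; adjust to wherever NVIDIASafetyConfig actually lives.
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig

cfg = NVIDIASafetyConfig(
    guardrails_service_url="http://localhost:7331",  # ${env.GUARDRAILS_SERVICE_URL:...}
    config_id="self-check",
)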
@@ -26,11 +31,9 @@ providers:
         namespace: null
         db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
   safety:
-  - provider_id: nvidia
-    provider_type: remote::nvidia
-    config:
-      url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
-      config_id: self-check
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config: {}
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -93,57 +96,19 @@ metadata_store:
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
 models:
 - metadata: {}
-  model_id: meta-llama/Llama-3-8B-Instruct
+  model_id: ${env.INFERENCE_MODEL}
   provider_id: nvidia
-  provider_model_id: meta/llama3-8b-instruct
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-3-70B-Instruct
+  model_id: ${env.SAFETY_MODEL}
   provider_id: nvidia
-  provider_model_id: meta/llama3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.1-405b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: nvidia
-  provider_model_id: meta/llama-3.2-90b-vision-instruct
   model_type: llm
 shields:
 - shield_id: ${env.SAFETY_MODEL}
-  provider_id: nvidia
 vector_dbs: []
 datasets: []
 scoring_fns: []
-eval_tasks: []
+benchmarks: []
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search