run-with-safety memory

This commit is contained in:
Xi Yan 2024-12-03 20:54:59 -08:00
parent bc2452d2e9
commit 26f578cc1d
13 changed files with 18 additions and 47 deletions

View file

@ -69,7 +69,8 @@ def get_distribution_template() -> DistributionTemplate:
endpoint_name="${env.SAFETY_INFERENCE_ENDPOINT_NAME}", endpoint_name="${env.SAFETY_INFERENCE_ENDPOINT_NAME}",
), ),
), ),
] ],
"memory": [memory_provider],
}, },
default_models=[ default_models=[
inference_model, inference_model,

View file

@ -31,12 +31,6 @@ providers:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/faiss_store.db
- provider_id: chromadb
provider_type: remote::chromadb
config: {}
- provider_id: pgvector
provider_type: remote::pgvector
config: {}
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -70,7 +70,8 @@ def get_distribution_template() -> DistributionTemplate:
repo="${env.SAFETY_MODEL}", repo="${env.SAFETY_MODEL}",
), ),
), ),
] ],
"memory": [memory_provider],
}, },
default_models=[ default_models=[
inference_model, inference_model,

View file

@ -31,12 +31,6 @@ providers:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/faiss_store.db
- provider_id: chromadb
provider_type: remote::chromadb
config: {}
- provider_id: pgvector
provider_type: remote::pgvector
config: {}
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -77,6 +77,7 @@ def get_distribution_template() -> DistributionTemplate:
), ),
), ),
], ],
"memory": [memory_provider],
}, },
default_models=[ default_models=[
inference_model, inference_model,

View file

@ -33,12 +33,6 @@ providers:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
- provider_id: chromadb
provider_type: remote::chromadb
config: {}
- provider_id: pgvector
provider_type: remote::pgvector
config: {}
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -7,6 +7,7 @@
from pathlib import Path from pathlib import Path
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -22,12 +23,17 @@ def get_distribution_template() -> DistributionTemplate:
"datasetio": ["remote::huggingface", "inline::localfs"], "datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
} }
name = "ollama"
inference_provider = Provider( inference_provider = Provider(
provider_id="ollama", provider_id="ollama",
provider_type="remote::ollama", provider_type="remote::ollama",
config=OllamaImplConfig.sample_run_config(), config=OllamaImplConfig.sample_run_config(),
) )
memory_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissImplConfig.sample_run_config(f"distributions/{name}"),
)
inference_model = ModelInput( inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}", model_id="${env.INFERENCE_MODEL}",
@ -39,7 +45,7 @@ def get_distribution_template() -> DistributionTemplate:
) )
return DistributionTemplate( return DistributionTemplate(
name="ollama", name=name,
distro_type="self_hosted", distro_type="self_hosted",
description="Use (an external) Ollama server for running LLM inference", description="Use (an external) Ollama server for running LLM inference",
docker_image=None, docker_image=None,
@ -50,6 +56,7 @@ def get_distribution_template() -> DistributionTemplate:
"run.yaml": RunConfigSettings( "run.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={
"inference": [inference_provider], "inference": [inference_provider],
"memory": [memory_provider],
}, },
default_models=[inference_model], default_models=[inference_model],
), ),
@ -57,7 +64,8 @@ def get_distribution_template() -> DistributionTemplate:
provider_overrides={ provider_overrides={
"inference": [ "inference": [
inference_provider, inference_provider,
] ],
"memory": [memory_provider],
}, },
default_models=[ default_models=[
inference_model, inference_model,

View file

@ -25,12 +25,6 @@ providers:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
- provider_id: chromadb
provider_type: remote::chromadb
config: {}
- provider_id: pgvector
provider_type: remote::pgvector
config: {}
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -25,12 +25,6 @@ providers:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
- provider_id: chromadb
provider_type: remote::chromadb
config: {}
- provider_id: pgvector
provider_type: remote::pgvector
config: {}
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -30,12 +30,6 @@ providers:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/faiss_store.db
- provider_id: chromadb
provider_type: remote::chromadb
config: {}
- provider_id: pgvector
provider_type: remote::pgvector
config: {}
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -70,6 +70,7 @@ def get_distribution_template() -> DistributionTemplate:
), ),
), ),
], ],
"memory": [memory_provider],
}, },
default_models=[ default_models=[
inference_model, inference_model,

View file

@ -29,12 +29,6 @@ providers:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/faiss_store.db
- provider_id: chromadb
provider_type: remote::chromadb
config: {}
- provider_id: pgvector
provider_type: remote::pgvector
config: {}
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -74,6 +74,7 @@ def get_distribution_template() -> DistributionTemplate:
), ),
), ),
], ],
"memory": [memory_provider],
}, },
default_models=[ default_models=[
inference_model, inference_model,