mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
add more RoutableProviders
This commit is contained in:
parent
c17c17cb19
commit
878b2c31c7
7 changed files with 100 additions and 43 deletions
|
@ -40,6 +40,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
for keys, p in self.unique_providers:
|
for keys, p in self.unique_providers:
|
||||||
|
spec = p.__provider_spec__
|
||||||
|
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
|
||||||
|
continue
|
||||||
|
|
||||||
await p.register_routing_keys(keys)
|
await p.register_routing_keys(keys)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
|
|
@ -13,7 +13,7 @@ import chromadb
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
from llama_stack.distribution.datatypes import RoutableProvider
|
||||||
|
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
BankWithIndex,
|
BankWithIndex,
|
||||||
|
@ -65,7 +65,7 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class ChromaMemoryAdapter(Memory):
|
class ChromaMemoryAdapter(Memory, RoutableProvider):
|
||||||
def __init__(self, url: str) -> None:
|
def __init__(self, url: str) -> None:
|
||||||
print(f"Initializing ChromaMemoryAdapter with url: {url}")
|
print(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||||
url = url.rstrip("/")
|
url = url.rstrip("/")
|
||||||
|
@ -93,6 +93,13 @@ class ChromaMemoryAdapter(Memory):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
|
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
|
||||||
|
self.routing_keys = routing_keys
|
||||||
|
|
||||||
|
def get_routing_keys(self) -> List[str]:
|
||||||
|
return self.routing_keys
|
||||||
|
|
||||||
async def create_memory_bank(
|
async def create_memory_bank(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
|
|
|
@ -5,16 +5,17 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from psycopg2 import sql
|
from psycopg2 import sql
|
||||||
from psycopg2.extras import execute_values, Json
|
from psycopg2.extras import execute_values, Json
|
||||||
from pydantic import BaseModel
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
from llama_stack.distribution.datatypes import RoutableProvider
|
||||||
|
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
ALL_MINILM_L6_V2_DIMENSION,
|
ALL_MINILM_L6_V2_DIMENSION,
|
||||||
|
@ -118,7 +119,7 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class PGVectorMemoryAdapter(Memory):
|
class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
||||||
def __init__(self, config: PGVectorConfig) -> None:
|
def __init__(self, config: PGVectorConfig) -> None:
|
||||||
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
|
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -160,6 +161,13 @@ class PGVectorMemoryAdapter(Memory):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
|
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
|
||||||
|
self.routing_keys = routing_keys
|
||||||
|
|
||||||
|
def get_routing_keys(self) -> List[str]:
|
||||||
|
return self.routing_keys
|
||||||
|
|
||||||
async def create_memory_bank(
|
async def create_memory_bank(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
|
|
|
@ -4,47 +4,63 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from .config import BedrockSafetyConfig
|
import boto3
|
||||||
|
|
||||||
from llama_stack.apis.safety import * # noqa
|
from llama_stack.apis.safety import * # noqa
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
import json
|
from llama_stack.distribution.datatypes import RoutableProvider
|
||||||
import logging
|
|
||||||
|
|
||||||
import boto3
|
from .config import BedrockSafetyConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BedrockSafetyAdapter(Safety):
|
SUPPORTED_SHIELD_TYPES = [
|
||||||
|
"bedrock_guardrail",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockSafetyAdapter(Safety, RoutableProvider):
|
||||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||||
|
if not config.aws_profile:
|
||||||
|
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
if not self.config.aws_profile:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Missing boto_client aws_profile in model info::{self.config}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print(f"initializing with profile --- > {self.config}::")
|
print(f"initializing with profile --- > {self.config}")
|
||||||
self.boto_client_profile = self.config.aws_profile
|
|
||||||
self.boto_client = boto3.Session(
|
self.boto_client = boto3.Session(
|
||||||
profile_name=self.boto_client_profile
|
profile_name=self.config.aws_profile
|
||||||
).client("bedrock-runtime")
|
).client("bedrock-runtime")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
|
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
|
for key in routing_keys:
|
||||||
|
if key not in SUPPORTED_SHIELD_TYPES:
|
||||||
|
raise ValueError(f"Unknown safety shield type: {key}")
|
||||||
|
|
||||||
|
self.routing_keys = routing_keys
|
||||||
|
|
||||||
|
def get_routing_keys(self) -> List[str]:
|
||||||
|
return self.routing_keys
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
if shield_type not in SUPPORTED_SHIELD_TYPES:
|
||||||
|
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||||
|
|
||||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||||
```content = [
|
```content = [
|
||||||
{
|
{
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from llama_models.sku_list import resolve_model
|
|
||||||
from together import Together
|
from together import Together
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
@ -13,43 +12,43 @@ from llama_stack.apis.safety import (
|
||||||
SafetyViolation,
|
SafetyViolation,
|
||||||
ViolationLevel,
|
ViolationLevel,
|
||||||
)
|
)
|
||||||
|
from llama_stack.distribution.datatypes import RoutableProvider
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
|
|
||||||
from .config import TogetherSafetyConfig
|
from .config import TogetherSafetyConfig
|
||||||
|
|
||||||
|
|
||||||
SAFETY_SHIELD_TYPES = {
|
SAFETY_SHIELD_TYPES = {
|
||||||
|
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||||
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
|
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def shield_type_to_model_name(shield_type: str) -> str:
|
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
||||||
if shield_type == "llama_guard":
|
|
||||||
shield_type = "Llama-Guard-3-8B"
|
|
||||||
|
|
||||||
model = resolve_model(shield_type)
|
|
||||||
if (
|
|
||||||
model is None
|
|
||||||
or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES
|
|
||||||
or model.model_family is not ModelFamily.safety
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
|
|
||||||
|
|
||||||
|
|
||||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
|
|
||||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
|
for key in routing_keys:
|
||||||
|
if key not in SAFETY_SHIELD_TYPES:
|
||||||
|
raise ValueError(f"Unknown safety shield type: {key}")
|
||||||
|
self.routing_keys = routing_keys
|
||||||
|
|
||||||
|
def get_routing_keys(self) -> List[str]:
|
||||||
|
return self.routing_keys
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
if shield_type not in SAFETY_SHIELD_TYPES:
|
||||||
|
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||||
|
|
||||||
together_api_key = None
|
together_api_key = None
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
|
@ -59,7 +58,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
|
||||||
)
|
)
|
||||||
together_api_key = provider_data.together_api_key
|
together_api_key = provider_data.together_api_key
|
||||||
|
|
||||||
model_name = shield_type_to_model_name(shield_type)
|
model_name = SAFETY_SHIELD_TYPES[shield_type]
|
||||||
|
|
||||||
# messages can have role assistant or user
|
# messages can have role assistant or user
|
||||||
api_messages = []
|
api_messages = []
|
||||||
|
|
|
@ -14,6 +14,7 @@ import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_stack.distribution.datatypes import RoutableProvider
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
@ -62,7 +63,7 @@ class FaissIndex(EmbeddingIndex):
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class FaissMemoryImpl(Memory):
|
class FaissMemoryImpl(Memory, RoutableProvider):
|
||||||
def __init__(self, config: FaissImplConfig) -> None:
|
def __init__(self, config: FaissImplConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
@ -71,6 +72,13 @@ class FaissMemoryImpl(Memory):
|
||||||
|
|
||||||
async def shutdown(self) -> None: ...
|
async def shutdown(self) -> None: ...
|
||||||
|
|
||||||
|
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
|
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
|
||||||
|
self.routing_keys = routing_keys
|
||||||
|
|
||||||
|
def get_routing_keys(self) -> List[str]:
|
||||||
|
return self.routing_keys
|
||||||
|
|
||||||
async def create_memory_bank(
|
async def create_memory_bank(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
|
|
|
@ -4,13 +4,15 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.datatypes import Api
|
from llama_stack.distribution.datatypes import Api, RoutableProvider
|
||||||
|
|
||||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||||
OnViolationAction,
|
OnViolationAction,
|
||||||
|
@ -35,7 +37,7 @@ def resolve_and_get_path(model_name: str) -> str:
|
||||||
return model_dir
|
return model_dir
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceSafetyImpl(Safety):
|
class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
||||||
def __init__(self, config: SafetyConfig, deps) -> None:
|
def __init__(self, config: SafetyConfig, deps) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = deps[Api.inference]
|
self.inference_api = deps[Api.inference]
|
||||||
|
@ -46,6 +48,19 @@ class MetaReferenceSafetyImpl(Safety):
|
||||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||||
_ = PromptGuardShield.instance(model_dir)
|
_ = PromptGuardShield.instance(model_dir)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
|
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||||
|
for key in routing_keys:
|
||||||
|
if key not in available_shields:
|
||||||
|
raise ValueError(f"Unknown safety shield type: {key}")
|
||||||
|
self.routing_keys = routing_keys
|
||||||
|
|
||||||
|
def get_routing_keys(self) -> List[str]:
|
||||||
|
return self.routing_keys
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_type: str,
|
shield_type: str,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue