add more RoutableProviders

This commit is contained in:
Ashwin Bharambe 2024-09-30 15:51:45 -07:00
parent c17c17cb19
commit 878b2c31c7
7 changed files with 100 additions and 43 deletions

View file

@ -40,6 +40,10 @@ class CommonRoutingTableImpl(RoutingTable):
async def initialize(self) -> None:
for keys, p in self.unique_providers:
spec = p.__provider_spec__
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
continue
await p.register_routing_keys(keys)
async def shutdown(self) -> None:

View file

@ -13,7 +13,7 @@ import chromadb
from numpy.typing import NDArray
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
@ -65,7 +65,7 @@ class ChromaIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class ChromaMemoryAdapter(Memory):
class ChromaMemoryAdapter(Memory, RoutableProvider):
def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
@ -93,6 +93,13 @@ class ChromaMemoryAdapter(Memory):
async def shutdown(self) -> None:
pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def create_memory_bank(
self,
name: str,

View file

@ -5,16 +5,17 @@
# the root directory of this source tree.
import uuid
from typing import List, Tuple
import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from pydantic import BaseModel
from llama_stack.apis.memory import * # noqa: F403
from pydantic import BaseModel
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
@ -118,7 +119,7 @@ class PGVectorIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class PGVectorMemoryAdapter(Memory):
class PGVectorMemoryAdapter(Memory, RoutableProvider):
def __init__(self, config: PGVectorConfig) -> None:
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
self.config = config
@ -160,6 +161,13 @@ class PGVectorMemoryAdapter(Memory):
async def shutdown(self) -> None:
pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def create_memory_bank(
self,
name: str,

View file

@ -4,47 +4,63 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import traceback
from typing import Any, Dict, List
from .config import BedrockSafetyConfig
import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
import json
import logging
from llama_stack.distribution.datatypes import RoutableProvider
import boto3
from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
class BedrockSafetyAdapter(Safety):
SUPPORTED_SHIELD_TYPES = [
"bedrock_guardrail",
]
class BedrockSafetyAdapter(Safety, RoutableProvider):
def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config
async def initialize(self) -> None:
if not self.config.aws_profile:
raise RuntimeError(
f"Missing boto_client aws_profile in model info::{self.config}"
)
try:
print(f"initializing with profile --- > {self.config}::")
self.boto_client_profile = self.config.aws_profile
print(f"initializing with profile --- > {self.config}")
self.boto_client = boto3.Session(
profile_name=self.boto_client_profile
profile_name=self.config.aws_profile
).client("bedrock-runtime")
except Exception as e:
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
async def shutdown(self) -> None:
pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{

View file

@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.sku_list import resolve_model
from together import Together
from llama_models.llama3.api.datatypes import * # noqa: F403
@ -13,43 +12,43 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig
SAFETY_SHIELD_TYPES = {
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}
def shield_type_to_model_name(shield_type: str) -> str:
if shield_type == "llama_guard":
shield_type = "Llama-Guard-3-8B"
model = resolve_model(shield_type)
if (
model is None
or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES
or model.model_family is not ModelFamily.safety
):
raise ValueError(
f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}"
)
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SAFETY_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SAFETY_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
together_api_key = None
provider_data = self.get_request_provider_data()
@ -59,7 +58,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
)
together_api_key = provider_data.together_api_key
model_name = shield_type_to_model_name(shield_type)
model_name = SAFETY_SHIELD_TYPES[shield_type]
# messages can have role assistant or user
api_messages = []

View file

@ -14,6 +14,7 @@ import numpy as np
from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.utils.memory.vector_store import (
@ -62,7 +63,7 @@ class FaissIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory):
class FaissMemoryImpl(Memory, RoutableProvider):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.cache = {}
@ -71,6 +72,13 @@ class FaissMemoryImpl(Memory):
async def shutdown(self) -> None: ...
async def register_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def create_memory_bank(
self,
name: str,

View file

@ -4,13 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import Api, RoutableProvider
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
@ -35,7 +37,7 @@ def resolve_and_get_path(model_name: str) -> str:
return model_dir
class MetaReferenceSafetyImpl(Safety):
class MetaReferenceSafetyImpl(Safety, RoutableProvider):
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
@ -46,6 +48,19 @@ class MetaReferenceSafetyImpl(Safety):
model_dir = resolve_and_get_path(shield_cfg.model)
_ = PromptGuardShield.instance(model_dir)
async def shutdown(self) -> None:
pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
available_shields = [v.value for v in MetaReferenceShieldType]
for key in routing_keys:
if key not in available_shields:
raise ValueError(f"Unknown safety shield type: {key}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def run_shield(
self,
shield_type: str,