add more RoutableProviders

This commit is contained in:
Ashwin Bharambe 2024-09-30 15:51:45 -07:00
parent c17c17cb19
commit 878b2c31c7
7 changed files with 100 additions and 43 deletions

View file

@ -40,6 +40,10 @@ class CommonRoutingTableImpl(RoutingTable):
async def initialize(self) -> None: async def initialize(self) -> None:
for keys, p in self.unique_providers: for keys, p in self.unique_providers:
spec = p.__provider_spec__
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
continue
await p.register_routing_keys(keys) await p.register_routing_keys(keys)
async def shutdown(self) -> None: async def shutdown(self) -> None:

View file

@ -13,7 +13,7 @@ import chromadb
from numpy.typing import NDArray from numpy.typing import NDArray
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
@ -65,7 +65,7 @@ class ChromaIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
class ChromaMemoryAdapter(Memory): class ChromaMemoryAdapter(Memory, RoutableProvider):
def __init__(self, url: str) -> None: def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}") print(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/") url = url.rstrip("/")
@ -93,6 +93,13 @@ class ChromaMemoryAdapter(Memory):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
    """Record the memory bank routing keys handed to this adapter.

    Args:
        routing_keys: Keys to remember; the same list object is stored on
            ``self.routing_keys`` after being echoed to stdout.
    """
    announcement = f"[chroma] Registering memory bank routing keys: {routing_keys}"
    print(announcement)
    self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
    """Return the routing keys registered on this adapter.

    Returns:
        The list stored on ``self.routing_keys``, or an empty list when that
        attribute has not been assigned yet.
    """
    # The attribute is assigned in ``register_routing_keys``; if it has not
    # been set, report "no keys" rather than raising AttributeError.
    return getattr(self, "routing_keys", [])
async def create_memory_bank( async def create_memory_bank(
self, self,
name: str, name: str,

View file

@ -5,16 +5,17 @@
# the root directory of this source tree. # the root directory of this source tree.
import uuid import uuid
from typing import List, Tuple from typing import List, Tuple
import psycopg2 import psycopg2
from numpy.typing import NDArray from numpy.typing import NDArray
from psycopg2 import sql from psycopg2 import sql
from psycopg2.extras import execute_values, Json from psycopg2.extras import execute_values, Json
from pydantic import BaseModel
from llama_stack.apis.memory import * # noqa: F403
from pydantic import BaseModel
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION, ALL_MINILM_L6_V2_DIMENSION,
@ -118,7 +119,7 @@ class PGVectorIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
class PGVectorMemoryAdapter(Memory): class PGVectorMemoryAdapter(Memory, RoutableProvider):
def __init__(self, config: PGVectorConfig) -> None: def __init__(self, config: PGVectorConfig) -> None:
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}") print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
self.config = config self.config = config
@ -160,6 +161,13 @@ class PGVectorMemoryAdapter(Memory):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
    """Record the memory bank routing keys handed to this adapter.

    Args:
        routing_keys: Keys to remember; the same list object is stored on
            ``self.routing_keys`` after being echoed to stdout.
    """
    announcement = f"[pgvector] Registering memory bank routing keys: {routing_keys}"
    print(announcement)
    self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
    """Return the routing keys registered on this adapter.

    Returns:
        The list stored on ``self.routing_keys``, or an empty list when that
        attribute has not been assigned yet.
    """
    # The attribute is assigned in ``register_routing_keys``; if it has not
    # been set, report "no keys" rather than raising AttributeError.
    return getattr(self, "routing_keys", [])
async def create_memory_bank( async def create_memory_bank(
self, self,
name: str, name: str,

View file

@ -4,47 +4,63 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
import logging
import traceback import traceback
from typing import Any, Dict, List from typing import Any, Dict, List
from .config import BedrockSafetyConfig import boto3
from llama_stack.apis.safety import * # noqa from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
import json from llama_stack.distribution.datatypes import RoutableProvider
import logging
import boto3 from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BedrockSafetyAdapter(Safety): SUPPORTED_SHIELD_TYPES = [
"bedrock_guardrail",
]
class BedrockSafetyAdapter(Safety, RoutableProvider):
def __init__(self, config: BedrockSafetyConfig) -> None: def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None:
if not self.config.aws_profile:
raise RuntimeError(
f"Missing boto_client aws_profile in model info::{self.config}"
)
try: try:
print(f"initializing with profile --- > {self.config}::") print(f"initializing with profile --- > {self.config}")
self.boto_client_profile = self.config.aws_profile
self.boto_client = boto3.Session( self.boto_client = boto3.Session(
profile_name=self.boto_client_profile profile_name=self.config.aws_profile
).client("bedrock-runtime") ).client("bedrock-runtime")
except Exception as e: except Exception as e:
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
    """Validate and store the shield types routed to this adapter.

    Args:
        routing_keys: Shield type names; each must appear in
            ``SUPPORTED_SHIELD_TYPES``.

    Raises:
        ValueError: For the first key that is not a supported shield type.
            Nothing is stored in that case.
    """
    for candidate in routing_keys:
        if candidate in SUPPORTED_SHIELD_TYPES:
            continue
        raise ValueError(f"Unknown safety shield type: {candidate}")
    self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
    """Return the shield types registered on this adapter.

    Returns:
        The list stored on ``self.routing_keys``, or an empty list when that
        attribute has not been assigned yet.
    """
    # The attribute is assigned in ``register_routing_keys``; if it has not
    # been set, report "no keys" rather than raising AttributeError.
    return getattr(self, "routing_keys", [])
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ) -> RunShieldResponse:
if shield_type not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [ ```content = [
{ {

View file

@ -3,7 +3,6 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.sku_list import resolve_model
from together import Together from together import Together
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
@ -13,43 +12,43 @@ from llama_stack.apis.safety import (
SafetyViolation, SafetyViolation,
ViolationLevel, ViolationLevel,
) )
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig from .config import TogetherSafetyConfig
SAFETY_SHIELD_TYPES = { SAFETY_SHIELD_TYPES = {
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B", "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo", "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
} }
def shield_type_to_model_name(shield_type: str) -> str: class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
if shield_type == "llama_guard":
shield_type = "Llama-Guard-3-8B"
model = resolve_model(shield_type)
if (
model is None
or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES
or model.model_family is not ModelFamily.safety
):
raise ValueError(
f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}"
)
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
def __init__(self, config: TogetherSafetyConfig) -> None: def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def shutdown(self) -> None:
    """Shutdown hook; this implementation performs no work."""
    return None
async def register_routing_keys(self, routing_keys: List[str]) -> None:
    """Validate and store the shield types routed to this provider.

    Args:
        routing_keys: Shield type names; each must be a key of
            ``SAFETY_SHIELD_TYPES``.

    Raises:
        ValueError: For the first key that is not a known shield type.
            Nothing is stored in that case.
    """
    for candidate in routing_keys:
        if candidate in SAFETY_SHIELD_TYPES:
            continue
        raise ValueError(f"Unknown safety shield type: {candidate}")
    self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
    """Return the shield types registered on this provider.

    Returns:
        The list stored on ``self.routing_keys``, or an empty list when that
        attribute has not been assigned yet.
    """
    # The attribute is assigned in ``register_routing_keys``; if it has not
    # been set, report "no keys" rather than raising AttributeError.
    return getattr(self, "routing_keys", [])
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ) -> RunShieldResponse:
if shield_type not in SAFETY_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
together_api_key = None together_api_key = None
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
@ -59,7 +58,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
) )
together_api_key = provider_data.together_api_key together_api_key = provider_data.together_api_key
model_name = shield_type_to_model_name(shield_type) model_name = SAFETY_SHIELD_TYPES[shield_type]
# messages can have role assistant or user # messages can have role assistant or user
api_messages = [] api_messages = []

View file

@ -14,6 +14,7 @@ import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
@ -62,7 +63,7 @@ class FaissIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory): class FaissMemoryImpl(Memory, RoutableProvider):
def __init__(self, config: FaissImplConfig) -> None: def __init__(self, config: FaissImplConfig) -> None:
self.config = config self.config = config
self.cache = {} self.cache = {}
@ -71,6 +72,13 @@ class FaissMemoryImpl(Memory):
async def shutdown(self) -> None: ... async def shutdown(self) -> None: ...
async def register_routing_keys(self, routing_keys: List[str]) -> None:
    """Record the memory bank routing keys handed to this provider.

    Args:
        routing_keys: Keys to remember; the same list object is stored on
            ``self.routing_keys`` after being echoed to stdout.
    """
    announcement = f"[faiss] Registering memory bank routing keys: {routing_keys}"
    print(announcement)
    self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
    """Return the routing keys registered on this provider.

    Returns:
        The list stored on ``self.routing_keys``, or an empty list when that
        attribute has not been assigned yet.
    """
    # The attribute is assigned in ``register_routing_keys``; if it has not
    # been set, report "no keys" rather than raising AttributeError.
    return getattr(self, "routing_keys", [])
async def create_memory_bank( async def create_memory_bank(
self, self,
name: str, name: str,

View file

@ -4,13 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api, RoutableProvider
from llama_stack.providers.impls.meta_reference.safety.shields.base import ( from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction, OnViolationAction,
@ -35,7 +37,7 @@ def resolve_and_get_path(model_name: str) -> str:
return model_dir return model_dir
class MetaReferenceSafetyImpl(Safety): class MetaReferenceSafetyImpl(Safety, RoutableProvider):
def __init__(self, config: SafetyConfig, deps) -> None: def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config self.config = config
self.inference_api = deps[Api.inference] self.inference_api = deps[Api.inference]
@ -46,6 +48,19 @@ class MetaReferenceSafetyImpl(Safety):
model_dir = resolve_and_get_path(shield_cfg.model) model_dir = resolve_and_get_path(shield_cfg.model)
_ = PromptGuardShield.instance(model_dir) _ = PromptGuardShield.instance(model_dir)
async def shutdown(self) -> None:
    """Shutdown hook; this implementation performs no work."""
    return None
async def register_routing_keys(self, routing_keys: List[str]) -> None:
    """Validate and store the shield types routed to this provider.

    Args:
        routing_keys: Shield type names; each must equal the ``value`` of a
            ``MetaReferenceShieldType`` member.

    Raises:
        ValueError: For the first key that matches no member value.
            Nothing is stored in that case.
    """
    # Collected before the loop so the enum is walked once per call.
    known_values = [member.value for member in MetaReferenceShieldType]
    for candidate in routing_keys:
        if candidate in known_values:
            continue
        raise ValueError(f"Unknown safety shield type: {candidate}")
    self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
    """Return the shield types registered on this provider.

    Returns:
        The list stored on ``self.routing_keys``, or an empty list when that
        attribute has not been assigned yet.
    """
    # The attribute is assigned in ``register_routing_keys``; if it has not
    # been set, report "no keys" rather than raising AttributeError.
    return getattr(self, "routing_keys", [])
async def run_shield( async def run_shield(
self, self,
shield_type: str, shield_type: str,