mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 07:14:20 +00:00)

Introduce model_store, shield_store, memory_bank_store

This commit is contained in:
  parent e45a417543
  commit 91e0063593

19 changed files with 172 additions and 297 deletions
@@ -173,7 +173,13 @@ class EmbeddingsResponse(BaseModel):
     embeddings: List[List[float]]
 
 
+class ModelStore(Protocol):
+    def get_model(self, identifier: str) -> ModelDef: ...
+
+
 class Inference(Protocol):
+    model_store: ModelStore
+
     @webmethod(route="/inference/completion")
     async def completion(
         self,
@@ -207,9 +213,3 @@ class Inference(Protocol):
 
     @webmethod(route="/inference/register_model")
     async def register_model(self, model: ModelDef) -> None: ...
-
-    @webmethod(route="/inference/list_models")
-    async def list_models(self) -> List[ModelDef]: ...
-
-    @webmethod(route="/inference/get_model")
-    async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
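
The net effect of these two inference hunks: providers no longer serve list_models/get_model themselves, and instead read registered definitions from a model_store that the routing table injects (see the CommonRoutingTableImpl hunks below). A minimal sketch of a provider consuming the store; the adapter class and the abbreviated chat_completion signature are hypothetical, not part of this commit, and it assumes the store is used with await as the adapters in this commit do:

    # Hypothetical adapter sketch illustrating the new model_store contract.
    from llama_stack.apis.inference import *  # noqa: F403

    class MyInferenceAdapter(Inference):
        model_store: ModelStore  # injected by CommonRoutingTableImpl at wiring time

        async def chat_completion(self, model: str, messages: List[Message], **kwargs):
            # resolve the registered definition instead of a provider-local registry
            model_def = await self.model_store.get_model(model)
            if model_def is None:
                raise ValueError(f"Model {model} is not registered")
            ...  # dispatch to the backend using model_def
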
@@ -38,7 +38,13 @@ class QueryDocumentsResponse(BaseModel):
     scores: List[float]
 
 
+class MemoryBankStore(Protocol):
+    def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
+
+
 class Memory(Protocol):
+    memory_bank_store: MemoryBankStore
+
     # this will just block now until documents are inserted, but it should
     # probably return a Job instance which can be polled for completion
     @webmethod(route="/memory/insert")
@@ -80,9 +86,3 @@ class Memory(Protocol):
 
     @webmethod(route="/memory/register_memory_bank")
     async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
-
-    @webmethod(route="/memory/list_memory_banks")
-    async def list_memory_banks(self) -> List[MemoryBankDef]: ...
-
-    @webmethod(route="/memory/get_memory_bank")
-    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
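
Memory mirrors the same inversion: a provider consults memory_bank_store instead of serving list/get routes. A hypothetical provider-side lookup, not from this commit; the adapter class and the abbreviated query_documents signature are assumptions:

    # Hypothetical memory adapter sketch.
    from llama_stack.apis.memory import *  # noqa: F403

    class MyMemoryAdapter(Memory):
        memory_bank_store: MemoryBankStore  # injected by CommonRoutingTableImpl

        async def query_documents(self, bank_id: str, query, params=None):
            bank = await self.memory_bank_store.get_memory_bank(bank_id)
            if bank is None:
                raise ValueError(f"Bank {bank_id} not found")
            ...  # query the index that belongs to this bank
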
@@ -38,7 +38,13 @@ class RunShieldResponse(BaseModel):
     violation: Optional[SafetyViolation] = None
 
 
+class ShieldStore(Protocol):
+    def get_shield(self, identifier: str) -> ShieldDef: ...
+
+
 class Safety(Protocol):
+    shield_store: ShieldStore
+
     @webmethod(route="/safety/run_shield")
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
@@ -46,9 +52,3 @@ class Safety(Protocol):
 
     @webmethod(route="/safety/register_shield")
     async def register_shield(self, shield: ShieldDef) -> None: ...
-
-    @webmethod(route="/safety/list_shields")
-    async def list_shields(self) -> List[ShieldDef]: ...
-
-    @webmethod(route="/safety/get_shield")
-    async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ...
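
Safety follows suit: run_shield implementations resolve the definition through shield_store (the Bedrock, Together, and meta-reference hunks below do exactly this), while registration flows through the one remaining route. A hedged sketch of registering a shield against this API; the identifier, params, and the shields_routing_table variable are made up:

    # Hypothetical shield registration; values are illustrative.
    shield = ShieldDef(
        identifier="my_llama_guard",
        type=ShieldType.llama_guard.value,
        params={"model": "Llama-Guard-3-8B"},
    )
    await shields_routing_table.register_shield(shield)  # forwards to the provider
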
@@ -39,6 +39,14 @@ RoutedProtocol = Union[
 ]
 
 
+class ModelRegistry(Protocol):
+    def get_model(self, identifier: str) -> ModelDef: ...
+
+
+class MemoryBankRegistry(Protocol):
+    def get_memory_bank(self, identifier: str) -> MemoryBankDef: ...
+
+
 # Example: /inference, /safety
 class AutoRoutedProviderSpec(ProviderSpec):
     provider_type: str = "router"
@@ -151,6 +151,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
                 deps,
                 inner_impls,
             )
+            # TODO: ugh slightly redesign this shady looking code
             if "inner-" in api_str:
                 inner_impls_by_provider_id[api_str][provider.provider_id] = impl
             else:
@@ -15,6 +15,20 @@ from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
 
+def get_impl_api(p: Any) -> Api:
+    return p.__provider_spec__.api
+
+
+async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
+    api = get_impl_api(p)
+    if api == Api.inference:
+        await p.register_model(obj)
+    elif api == Api.safety:
+        await p.register_shield(obj)
+    elif api == Api.memory:
+        await p.register_memory_bank(obj)
+
+
 # TODO: this routing table maintains state in memory purely. We need to
 # add persistence to it when we add dynamic registration of objects.
 class CommonRoutingTableImpl(RoutingTable):
@@ -32,6 +46,15 @@ class CommonRoutingTableImpl(RoutingTable):
         self.impls_by_provider_id = impls_by_provider_id
         self.registry = registry
 
+        for p in self.impls_by_provider_id.values():
+            api = get_impl_api(p)
+            if api == Api.inference:
+                p.model_store = self
+            elif api == Api.safety:
+                p.shield_store = self
+            elif api == Api.memory:
+                p.memory_bank_store = self
+
         self.routing_key_to_object = {}
         for obj in self.registry:
             self.routing_key_to_object[obj.identifier] = obj
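
The constructor above makes the routing table itself the store that each provider reads from, and initialize() (next hunk) pushes every registered object down to its provider. An illustrative sketch of that lifecycle; the constructor kwargs, the fake provider, and the abbreviated provider spec are assumptions, and a real ModelDef and ProviderSpec carry more fields:

    # Illustrative lifecycle sketch; not code from this commit.
    from types import SimpleNamespace

    class FakeInferenceProvider:
        __provider_spec__ = SimpleNamespace(api=Api.inference)  # what get_impl_api reads

        async def register_model(self, model: ModelDef) -> None:
            print(f"provider asked to serve {model.identifier}")

    fake = FakeInferenceProvider()
    table = ModelsRoutingTable(impls_by_provider_id={"fake": fake}, registry=[])
    assert fake.model_store is table  # wired by the constructor above
    # register_model(...) -> register_object(...) -> register_object_with_provider(obj, fake)
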
@@ -39,7 +62,7 @@ class CommonRoutingTableImpl(RoutingTable):
     async def initialize(self) -> None:
         for obj in self.registry:
             p = self.impls_by_provider_id[obj.provider_id]
-            await self.register_object(obj, p)
+            await register_object_with_provider(obj, p)
 
     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
@@ -57,7 +80,7 @@ class CommonRoutingTableImpl(RoutingTable):
                 return obj
         return None
 
-    async def register_object_common(self, obj: RoutableObject) -> None:
+    async def register_object(self, obj: RoutableObject) -> Any:
         if obj.identifier in self.routing_key_to_object:
             raise ValueError(f"Object `{obj.identifier}` already registered")
 
@@ -65,16 +88,13 @@ class CommonRoutingTableImpl(RoutingTable):
             raise ValueError(f"Provider `{obj.provider_id}` not found")
 
         p = self.impls_by_provider_id[obj.provider_id]
-        await p.register_object(obj)
+        await register_object_with_provider(obj, p)
 
         self.routing_key_to_object[obj.identifier] = obj
         self.registry.append(obj)
 
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
-    async def register_object(self, obj: ModelDef, p: Inference) -> None:
-        await p.register_model(obj)
-
     async def list_models(self) -> List[ModelDef]:
         return self.registry
 
@@ -82,13 +102,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         return self.get_object_by_identifier(identifier)
 
     async def register_model(self, model: ModelDef) -> None:
-        await self.register_object_common(model)
+        await self.register_object(model)
 
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
-    async def register_object(self, obj: ShieldDef, p: Safety) -> None:
-        await p.register_shield(obj)
-
     async def list_shields(self) -> List[ShieldDef]:
         return self.registry
 
@@ -96,13 +113,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
         return self.get_object_by_identifier(shield_type)
 
     async def register_shield(self, shield: ShieldDef) -> None:
-        await self.register_object_common(shield)
+        await self.register_object(shield)
 
 
 class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
-    async def register_object(self, obj: MemoryBankDef, p: Memory) -> None:
-        await p.register_memory_bank(obj)
-
     async def list_memory_banks(self) -> List[MemoryBankDef]:
         return self.registry
 
@@ -110,4 +124,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         return self.get_object_by_identifier(identifier)
 
     async def register_memory_bank(self, bank: MemoryBankDef) -> None:
-        await self.register_object_common(bank)
+        await self.register_object(bank)
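
With that, the routing tables now expose the list/get routes that were dropped from the provider APIs earlier in this commit, serving them straight from the in-memory registry. Hypothetical call shapes (the table instance and identifier are stand-ins):

    # Hypothetical: list/get are answered from the registry, not from providers.
    models = await models_table.list_models()         # -> List[ModelDef], the registry itself
    model = await models_table.get_model("my-model")  # -> Optional[ModelDef] by routing key
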
@@ -6,39 +6,40 @@
 
 from typing import AsyncGenerator
 
-from openai import OpenAI
-
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
+
+from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 
 from .config import DatabricksImplConfig
 
 
 DATABRICKS_SUPPORTED_MODELS = {
     "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
     "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
 }
 
 
-class DatabricksInferenceAdapter(Inference):
+class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
+        )
         self.config = config
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
 
     @property
     def client(self) -> OpenAI:
-        return OpenAI(
-            base_url=self.config.url,
-            api_key=self.config.api_token
-        )
+        return OpenAI(base_url=self.config.url, api_key=self.config.api_token)
 
     async def initialize(self) -> None:
         return
@@ -65,18 +66,6 @@ class DatabricksInferenceAdapter(Inference):
 
         return databricks_messages
 
-    def resolve_databricks_model(self, model_name: str) -> str:
-        model = resolve_model(model_name)
-        assert (
-            model is not None
-            and model.descriptor(shorten_default_variant=True)
-            in DATABRICKS_SUPPORTED_MODELS
-        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(DATABRICKS_SUPPORTED_MODELS.keys())}"
-
-        return DATABRICKS_SUPPORTED_MODELS.get(
-            model.descriptor(shorten_default_variant=True)
-        )
-
     def get_databricks_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -110,10 +99,9 @@ class DatabricksInferenceAdapter(Inference):
 
         messages = augment_messages_for_tools(request)
         options = self.get_databricks_chat_options(request)
-        databricks_model = self.resolve_databricks_model(request.model)
+        databricks_model = self.map_to_provider_model(request.model)
 
         if not request.stream:
-
             r = self.client.chat.completions.create(
                 model=databricks_model,
                 messages=self._messages_to_databricks_messages(messages),
@@ -154,10 +142,7 @@ class DatabricksInferenceAdapter(Inference):
                 **options,
             ):
                 if chunk.choices[0].finish_reason:
-                    if (
-                        stop_reason is None
-                        and chunk.choices[0].finish_reason == "stop"
-                    ):
+                    if stop_reason is None and chunk.choices[0].finish_reason == "stop":
                         stop_reason = StopReason.end_of_turn
                     elif (
                         stop_reason is None

@@ -254,4 +239,4 @@ class DatabricksInferenceAdapter(Inference):
                         delta="",
                         stop_reason=stop_reason,
                     )
                 )
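
With ModelRegistryHelper mixed in, Databricks model resolution collapses to a lookup in DATABRICKS_SUPPORTED_MODELS instead of the bespoke resolve_databricks_model removed above. Roughly, as an illustration (the config variable is a stand-in):

    # Illustrative: behavior inherited from ModelRegistryHelper.
    adapter = DatabricksInferenceAdapter(config)  # config: a DatabricksImplConfig
    adapter.map_to_provider_model("Llama3.1-70B-Instruct")
    # -> "databricks-meta-llama-3-1-70b-instruct"
    # identifiers outside DATABRICKS_SUPPORTED_MODELS are rejected by the helper
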
@@ -21,7 +21,7 @@ from llama_stack.providers.utils.inference.augment_messages import (
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 
 
-OLLAMA_SUPPORTED_SKUS = {
+OLLAMA_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
     "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
@@ -33,7 +33,7 @@ OLLAMA_SUPPORTED_SKUS = {
 class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, url: str) -> None:
         ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
+            self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS
         )
         self.url = url
         tokenizer = Tokenizer.get_instance()
@@ -9,14 +9,12 @@ from .config import SampleConfig
 
 from llama_stack.apis.inference import *  # noqa: F403
 
-from llama_stack.distribution.datatypes import RoutableProvider
 
-
-class SampleInferenceImpl(Inference, RoutableProvider):
+class SampleInferenceImpl(Inference):
     def __init__(self, config: SampleConfig):
         self.config = config
 
-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+    async def register_model(self, model: ModelDef) -> None:
         # these are the model names the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass
@@ -12,6 +12,7 @@ from huggingface_hub import AsyncInferenceClient, HfApi
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
@@ -32,16 +33,18 @@ class _HfAdapter(Inference):
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
 
-    # TODO: make this work properly by checking this against the model_id being
-    # served by the remote endpoint
     async def register_model(self, model: ModelDef) -> None:
-        pass
-
-    async def list_models(self) -> List[ModelDef]:
-        return []
-
-    async def get_model(self, identifier: str) -> Optional[ModelDef]:
-        return None
+        resolved_model = resolve_model(model.identifier)
+        if resolved_model is None:
+            raise ValueError(f"Unknown model: {model.identifier}")
+
+        if not resolved_model.huggingface_repo:
+            raise ValueError(
+                f"Model {model.identifier} does not have a HuggingFace repo"
+            )
+
+        if self.model_id != resolved_model.huggingface_repo:
+            raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
 
     async def shutdown(self) -> None:
         pass
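
register_model on the HF-style adapter is no longer a no-op: it resolves the identifier through llama_models' sku list and insists the endpoint is actually serving that model's HuggingFace repo. For reference, a quick illustrative look at what resolve_model supplies for those checks (the printed repo value is an expectation, not verified here):

    # Illustrative: the data register_model checks against.
    from llama_models.sku_list import resolve_model

    m = resolve_model("Llama3.1-8B-Instruct")
    assert m is not None
    print(m.huggingface_repo)  # e.g. "meta-llama/Llama-3.1-8B-Instruct"
    # register_model raises ValueError unless this repo matches self.model_id
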
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import json
-import uuid
 from typing import List
 from urllib.parse import urlparse
 

@@ -13,7 +12,6 @@ import chromadb
 from numpy.typing import NDArray
 
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,

@@ -65,7 +63,7 @@ class ChromaIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class ChromaMemoryAdapter(Memory, RoutableProvider):
+class ChromaMemoryAdapter(Memory):
     def __init__(self, url: str) -> None:
         print(f"Initializing ChromaMemoryAdapter with url: {url}")
         url = url.rstrip("/")

@@ -93,48 +91,33 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
-        pass
-
-    async def create_memory_bank(
+    async def register_memory_bank(
         self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        bank_id = str(uuid.uuid4())
-        bank = MemoryBank(
-            bank_id=bank_id,
-            name=name,
-            config=config,
-            url=url,
-        )
+        memory_bank: MemoryBankDef,
+    ) -> None:
+        assert (
+            memory_bank.type == MemoryBankType.vector.value
+        ), f"Only vector banks are supported {memory_bank.type}"
+
         collection = await self.client.create_collection(
-            name=bank_id,
-            metadata={"bank": bank.json()},
+            name=memory_bank.identifier,
         )
         bank_index = BankWithIndex(
-            bank=bank, index=ChromaIndex(self.client, collection)
+            bank=memory_bank, index=ChromaIndex(self.client, collection)
        )
-        self.cache[bank_id] = bank_index
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        bank_index = await self._get_and_cache_bank_index(bank_id)
-        if bank_index is None:
-            return None
-        return bank_index.bank
+        self.cache[memory_bank.identifier] = bank_index
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
             return self.cache[bank_id]
 
+        bank = await self.memory_bank_store.get_memory_bank(bank_id)
+        if bank is None:
+            raise ValueError(f"Bank {bank_id} not found")
+
         collections = await self.client.list_collections()
         for collection in collections:
             if collection.name == bank_id:
-                print(collection.metadata)
-                bank = MemoryBank(**json.loads(collection.metadata["bank"]))
                 index = BankWithIndex(
                     bank=bank,
                     index=ChromaIndex(self.client, collection),
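
Note the storage shape change here: bank metadata no longer round-trips through Chroma collection metadata, the routing table's memory_bank_store is the single source of truth, and collections are matched purely by name. A hedged sketch of the cold-cache path; the URL, the routing_table variable, and the manual store assignment are stand-ins for what CommonRoutingTableImpl normally does:

    # Illustrative cold-cache recovery under the new scheme.
    adapter = ChromaMemoryAdapter(url="http://localhost:6000")  # made-up endpoint
    await adapter.initialize()
    adapter.memory_bank_store = routing_table  # normally injected by CommonRoutingTableImpl
    index = await adapter._get_and_cache_bank_index("my-bank")
    # -> store lookup for the MemoryBankDef, then a name match over client.list_collections()
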
@@ -4,18 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import uuid
-from typing import List, Tuple
+from typing import List
 
 import psycopg2
 from numpy.typing import NDArray
 from psycopg2 import sql
 from psycopg2.extras import execute_values, Json
 
-from pydantic import BaseModel
-
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 
 from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,

@@ -32,33 +28,6 @@ def check_extension_version(cur):
     return result[0] if result else None
 
 
-def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
-    query = sql.SQL(
-        """
-        INSERT INTO metadata_store (key, data)
-        VALUES %s
-        ON CONFLICT (key) DO UPDATE
-        SET data = EXCLUDED.data
-    """
-    )
-
-    values = [(key, Json(model.dict())) for key, model in keys_models]
-    execute_values(cur, query, values, template="(%s, %s)")
-
-
-def load_models(cur, keys: List[str], cls):
-    query = "SELECT key, data FROM metadata_store"
-    if keys:
-        placeholders = ",".join(["%s"] * len(keys))
-        query += f" WHERE key IN ({placeholders})"
-        cur.execute(query, keys)
-    else:
-        cur.execute(query)
-
-    rows = cur.fetchall()
-    return [cls(**row["data"]) for row in rows]
-
-
 class PGVectorIndex(EmbeddingIndex):
     def __init__(self, bank: MemoryBank, dimension: int, cursor):
         self.cursor = cursor

@@ -119,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class PGVectorMemoryAdapter(Memory, RoutableProvider):
+class PGVectorMemoryAdapter(Memory):
     def __init__(self, config: PGVectorConfig) -> None:
         print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
         self.config = config

@@ -144,14 +113,6 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
             else:
                 raise RuntimeError("Vector extension is not installed.")
 
-            self.cursor.execute(
-                """
-                CREATE TABLE IF NOT EXISTS metadata_store (
-                    key TEXT PRIMARY KEY,
-                    data JSONB
-                )
-            """
-            )
         except Exception as e:
             import traceback

@@ -161,51 +122,28 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
-        pass
-
-    async def create_memory_bank(
+    async def register_memory_bank(
         self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        bank_id = str(uuid.uuid4())
-        bank = MemoryBank(
-            bank_id=bank_id,
-            name=name,
-            config=config,
-            url=url,
-        )
-        upsert_models(
-            self.cursor,
-            [
-                (bank.bank_id, bank),
-            ],
-        )
+        memory_bank: MemoryBankDef,
+    ) -> None:
+        assert (
+            memory_bank.type == MemoryBankType.vector.value
+        ), f"Only vector banks are supported {memory_bank.type}"
+
         index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+            bank=memory_bank,
+            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
         )
         self.cache[bank_id] = index
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        bank_index = await self._get_and_cache_bank_index(bank_id)
-        if bank_index is None:
-            return None
-        return bank_index.bank
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        if bank_id in self.cache:
             return self.cache[bank_id]
 
-        banks = load_models(self.cursor, [bank_id], MemoryBank)
-        if not banks:
-            return None
+        bank = await self.memory_bank_store.get_memory_bank(bank_id)
+        if not bank:
+            raise ValueError(f"Bank {bank_id} not found")
 
-        bank = banks[0]
         index = BankWithIndex(
             bank=bank,
             index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
@@ -9,14 +9,12 @@ from .config import SampleConfig
 
 from llama_stack.apis.memory import *  # noqa: F403
 
-from llama_stack.distribution.datatypes import RoutableProvider
 
-
-class SampleMemoryImpl(Memory, RoutableProvider):
+class SampleMemoryImpl(Memory):
     def __init__(self, config: SampleConfig):
         self.config = config
 
-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
         # these are the memory banks the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass
@@ -7,14 +7,12 @@
 import json
 import logging
 
-import traceback
 from typing import Any, Dict, List
 
 import boto3
 
 from llama_stack.apis.safety import *  # noqa
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 
 from .config import BedrockSafetyConfig
 
@@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
 logger = logging.getLogger(__name__)
 
 
-SUPPORTED_SHIELD_TYPES = [
-    "bedrock_guardrail",
+BEDROCK_SUPPORTED_SHIELDS = [
+    ShieldType.generic_content_shield.value,
 ]
 
 
-class BedrockSafetyAdapter(Safety, RoutableProvider):
+class BedrockSafetyAdapter(Safety):
     def __init__(self, config: BedrockSafetyConfig) -> None:
         if not config.aws_profile:
             raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
         self.config = config
+        self.registered_shields = []
 
     async def initialize(self) -> None:
         try:
@@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        for key in routing_keys:
-            if key not in SUPPORTED_SHIELD_TYPES:
-                raise ValueError(f"Unknown safety shield type: {key}")
+    async def register_shield(self, shield: ShieldDef) -> None:
+        if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
+            raise ValueError(f"Unsupported safety shield type: {shield.type}")
+
+        shield_params = shield.params
+        if "guardrailIdentifier" not in shield_params:
+            raise ValueError(
+                "Error running request for BedrockGaurdrails:Missing GuardrailID in request"
+            )
+
+        if "guardrailVersion" not in shield_params:
+            raise ValueError(
+                "Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
+            )
 
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        if shield_type not in SUPPORTED_SHIELD_TYPES:
-            raise ValueError(f"Unknown safety shield type: {shield_type}")
+        shield_def = await self.shield_store.get_shield(shield_type)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {shield_type}")
 
         """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
         ```content = [
@@ -69,52 +79,38 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
 
         They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
         """
-        try:
-            logger.debug(f"run_shield::{params}::messages={messages}")
-            if "guardrailIdentifier" not in params:
-                raise RuntimeError(
-                    "Error running request for BedrockGaurdrails:Missing GuardrailID in request"
-                )
-
-            if "guardrailVersion" not in params:
-                raise RuntimeError(
-                    "Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
-                )
+        shield_params = shield_def.params
+        logger.debug(f"run_shield::{shield_params}::messages={messages}")
 
         # - convert the messages into format Bedrock expects
         content_messages = []
         for message in messages:
             content_messages.append({"text": {"text": message.content}})
         logger.debug(
             f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
         )
 
         response = self.boto_client.apply_guardrail(
-            guardrailIdentifier=params.get("guardrailIdentifier"),
-            guardrailVersion=params.get("guardrailVersion"),
+            guardrailIdentifier=shield_params["guardrailIdentifier"],
+            guardrailVersion=shield_params["guardrailVersion"],
             source="OUTPUT",  # or 'INPUT' depending on your use case
             content=content_messages,
         )
-        logger.debug(f"run_shield:: response: {response}::")
         if response["action"] == "GUARDRAIL_INTERVENED":
             user_message = ""
             metadata = {}
             for output in response["outputs"]:
                 # guardrails returns a list - however for this implementation we will leverage the last values
                 user_message = output["text"]
             for assessment in response["assessments"]:
                 # guardrails returns a list - however for this implementation we will leverage the last values
                 metadata = dict(assessment)
 
             return SafetyViolation(
                 user_message=user_message,
                 violation_level=ViolationLevel.ERROR,
                 metadata=metadata,
             )
-
-        except Exception:
-            error_str = traceback.format_exc()
-            logger.error(
-                f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
-            )
 
         return None
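
Under the new contract, the guardrail coordinates live on the shield definition rather than being passed on each run_shield call. A hypothetical ShieldDef that would satisfy the register_shield checks above; the identifier and guardrail values are made up:

    # Hypothetical Bedrock shield definition; values are placeholders.
    shield = ShieldDef(
        identifier="my_bedrock_guardrail",
        type=ShieldType.generic_content_shield.value,
        params={
            "guardrailIdentifier": "gr-abc123",  # made-up guardrail id
            "guardrailVersion": "1",
        },
    )
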
@@ -9,14 +9,12 @@ from .config import SampleConfig
 
 from llama_stack.apis.safety import *  # noqa: F403
 
-from llama_stack.distribution.datatypes import RoutableProvider
 
-
-class SampleSafetyImpl(Safety, RoutableProvider):
+class SampleSafetyImpl(Safety):
     def __init__(self, config: SampleConfig):
         self.config = config
 
-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+    async def register_shield(self, shield: ShieldDef) -> None:
         # these are the safety shields the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass
@@ -12,7 +12,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from .config import TogetherSafetyConfig
 
 
-SAFETY_SHIELD_MODEL_MAP = {
+TOGETHER_SHIELD_MODEL_MAP = {
     "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
     "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
     "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
@@ -22,7 +22,6 @@ SAFETY_SHIELD_MODEL_MAP = {
 class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
     def __init__(self, config: TogetherSafetyConfig) -> None:
         self.config = config
-        self.register_shields = []
 
     async def initialize(self) -> None:
         pass

@@ -34,26 +33,15 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
         if shield.type != ShieldType.llama_guard.value:
             raise ValueError(f"Unsupported safety shield type: {shield.type}")
 
-        self.registered_shields.append(shield)
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return self.registered_shields
-
-    async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
-        for shield in self.registered_shields:
-            if shield.identifier == identifier:
-                return shield
-        return None
-
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        shield_def = await self.get_shield(shield_type)
+        shield_def = await self.shield_store.get_shield(shield_type)
         if not shield_def:
             raise ValueError(f"Unknown shield {shield_type}")
 
         model = shield_def.params.get("model", "llama_guard")
-        if model not in SAFETY_SHIELD_MODEL_MAP:
+        if model not in TOGETHER_SHIELD_MODEL_MAP:
             raise ValueError(f"Unsupported safety model: {model}")
 
         together_api_key = None

@@ -73,7 +61,9 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
             if message.role in (Role.user.value, Role.assistant.value):
                 api_messages.append({"role": message.role, "content": message.content})
 
-        violation = await get_safety_response(together_api_key, model, api_messages)
+        violation = await get_safety_response(
+            together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
+        )
         return RunShieldResponse(violation=violation)
@@ -83,15 +83,6 @@ class FaissMemoryImpl(Memory):
         )
         self.cache[memory_bank.identifier] = index
 
-    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
-        index = self.cache.get(identifier)
-        if index is None:
-            return None
-        return index.bank
-
-    async def list_memory_banks(self) -> List[MemoryBankDef]:
-        return [x.bank for x in self.cache.values()]
-
     async def insert_documents(
         self,
         bank_id: str,
@@ -33,7 +33,6 @@ class MetaReferenceSafetyImpl(Safety):
     def __init__(self, config: SafetyConfig, deps) -> None:
         self.config = config
         self.inference_api = deps[Api.inference]
-        self.registered_shields = []
 
         self.available_shields = [ShieldType.code_scanner.value]
         if config.llama_guard_shield:

@@ -55,24 +54,13 @@ class MetaReferenceSafetyImpl(Safety):
         if shield.type not in self.available_shields:
             raise ValueError(f"Unsupported safety shield type: {shield.type}")
 
-        self.registered_shields.append(shield)
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return self.registered_shields
-
-    async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
-        for shield in self.registered_shields:
-            if shield.identifier == identifier:
-                return shield
-        return None
-
     async def run_shield(
         self,
         shield_type: str,
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        shield_def = await self.get_shield(shield_type)
+        shield_def = await self.shield_store.get_shield(shield_type)
         if not shield_def:
             raise ValueError(f"Unknown shield {shield_type}")
 
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Dict, List
+from typing import Dict
 
 from llama_models.sku_list import resolve_model
 
@@ -15,7 +15,6 @@ class ModelRegistryHelper:
 
     def __init__(self, stack_to_provider_models_map: Dict[str, str]):
         self.stack_to_provider_models_map = stack_to_provider_models_map
-        self.registered_models = []
 
     def map_to_provider_model(self, identifier: str) -> str:
         model = resolve_model(identifier)

@@ -30,22 +29,7 @@ class ModelRegistryHelper:
         return self.stack_to_provider_models_map[identifier]
 
     async def register_model(self, model: ModelDef) -> None:
-        existing = await self.get_model(model.identifier)
-        if existing is not None:
-            return
-
         if model.identifier not in self.stack_to_provider_models_map:
             raise ValueError(
                 f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
             )
-
-        self.registered_models.append(model)
-
-    async def list_models(self) -> List[ModelDef]:
-        return self.registered_models
-
-    async def get_model(self, identifier: str) -> Optional[ModelDef]:
-        for model in self.registered_models:
-            if model.identifier == identifier:
-                return model
-        return None
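
After this change the helper is stateless: register_model only validates the identifier against the map, and listing and lookup belong to the routing table. A quick illustrative check (the single-entry map is made up):

    # Illustrative: the helper now only maps and validates.
    helper = ModelRegistryHelper(
        stack_to_provider_models_map={"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"}
    )
    print(helper.map_to_provider_model("Llama3.1-8B-Instruct"))  # llama3.1:8b-instruct-fp16
    # register_model(model) raises ValueError for identifiers outside the map
    # and no longer keeps a registered_models list
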