Add a RoutableProvider protocol, support for multiple routing keys (#163)

* Update configure.py to use multiple routing keys for safety
* Refactor distribution/datatypes into providers/datatypes
* Cleanup
Author: Ashwin Bharambe (2024-09-30 17:30:21 -07:00), committed by GitHub
parent 73decb3781
commit eb2d8a31a5
24 changed files with 600 additions and 577 deletions
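
The RoutableProvider definition itself lives in llama_stack.distribution.datatypes and is not part of this excerpt. Judging from the three implementations below, it is essentially a single-method interface; a minimal sketch, assuming a typing.Protocol (the real definition may differ):

    from typing import List, Protocol


    class RoutableProvider(Protocol):
        # Sketch only: inferred from the implementations in this commit.
        # A routable provider is registered under one or more routing keys
        # (a model name for inference, memory bank ids for memory, shield
        # types for safety) and gets a chance to validate them at setup.
        async def validate_routing_keys(self, routing_keys: List[str]) -> None:
            ...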

Changed file: meta-reference inference provider

@@ -6,21 +6,13 @@
 import asyncio
-from typing import AsyncIterator, Union
+from typing import AsyncIterator, List, Union
 
-from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model
 
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.distribution.datatypes import RoutableProvider
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
@@ -28,15 +20,12 @@ from llama_stack.providers.utils.inference.augment_messages import (
 from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator
 
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
 
 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-class MetaReferenceInferenceImpl(Inference):
+class MetaReferenceInferenceImpl(Inference, RoutableProvider):
     def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
@@ -49,6 +38,12 @@ class MetaReferenceInferenceImpl(Inference):
         self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()
 
+    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
+        assert (
+            len(routing_keys) == 1
+        ), f"Only one routing key is supported {routing_keys}"
+        assert routing_keys[0] == self.config.model
+
     async def shutdown(self) -> None:
         self.generator.stop()
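
Because the meta-reference provider serves exactly one model, its validator insists on a single routing key that matches config.model. A usage sketch (the model name is illustrative, and MetaReferenceImplConfig may require more fields than shown here):

    import asyncio

    # Illustrative configuration; the model name is an assumption.
    config = MetaReferenceImplConfig(model="Llama3.1-8B-Instruct")
    impl = MetaReferenceInferenceImpl(config)

    # Passes: one key, equal to the configured model.
    asyncio.run(impl.validate_routing_keys(["Llama3.1-8B-Instruct"]))

    # Raises AssertionError: only one routing key is supported.
    asyncio.run(
        impl.validate_routing_keys(["Llama3.1-8B-Instruct", "Llama3.1-70B-Instruct"])
    )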

Changed file: Faiss memory provider

@@ -14,6 +14,7 @@ import numpy as np
 from numpy.typing import NDArray
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.distribution.datatypes import RoutableProvider
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.providers.utils.memory.vector_store import (
@@ -62,7 +63,7 @@ class FaissIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class FaissMemoryImpl(Memory):
+class FaissMemoryImpl(Memory, RoutableProvider):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
         self.cache = {}
@@ -71,6 +72,10 @@ class FaissMemoryImpl(Memory):
 
     async def shutdown(self) -> None: ...
 
+    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
+        print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
+        pass
+
     async def create_memory_bank(
         self,
         name: str,
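
The Faiss provider, by contrast, accepts any routing keys: memory banks are created at runtime through create_memory_bank, so there is nothing to check up front and the method only logs (the trailing pass is redundant but harmless). For illustration, with a hypothetical faiss_impl instance and made-up bank ids:

    # Any keys pass validation; faiss_impl and the bank ids are made up.
    asyncio.run(faiss_impl.validate_routing_keys(["vector-bank-1", "docs"]))
    # prints: [faiss] Registering memory bank routing keys: ['vector-bank-1', 'docs']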

Changed file: meta-reference safety provider

@@ -4,13 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Any, Dict, List
+
 from llama_models.sku_list import resolve_model
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import Api
+from llama_stack.distribution.datatypes import Api, RoutableProvider
 from llama_stack.providers.impls.meta_reference.safety.shields.base import (
     OnViolationAction,
@@ -35,7 +37,7 @@ def resolve_and_get_path(model_name: str) -> str:
     return model_dir
 
 
-class MetaReferenceSafetyImpl(Safety):
+class MetaReferenceSafetyImpl(Safety, RoutableProvider):
     def __init__(self, config: SafetyConfig, deps) -> None:
         self.config = config
         self.inference_api = deps[Api.inference]
@@ -46,6 +48,15 @@ class MetaReferenceSafetyImpl(Safety):
             model_dir = resolve_and_get_path(shield_cfg.model)
             _ = PromptGuardShield.instance(model_dir)
 
+    async def shutdown(self) -> None:
+        pass
+
+    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
+        available_shields = [v.value for v in MetaReferenceShieldType]
+        for key in routing_keys:
+            if key not in available_shields:
+                raise ValueError(f"Unknown safety shield type: {key}")
+
     async def run_shield(
         self,
         shield_type: str,
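
For safety, routing keys are shield types and must match a member of the MetaReferenceShieldType enum. The enum's values are not shown in this diff, so the valid key below is an assumption; safety_impl stands in for a constructed MetaReferenceSafetyImpl:

    # "llama_guard" is assumed to be a MetaReferenceShieldType value
    # (the enum is not shown in this diff).
    asyncio.run(safety_impl.validate_routing_keys(["llama_guard"]))

    # An unrecognized shield type raises:
    asyncio.run(safety_impl.validate_routing_keys(["no_such_shield"]))
    # ValueError: Unknown safety shield type: no_such_shield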