From 335dea849a63cd2d1fba7bf3f78262d51989ae0f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 2 Oct 2024 10:09:36 -0700 Subject: [PATCH] fix sample impls --- .../providers/adapters/inference/sample/sample.py | 9 ++++++++- llama_stack/providers/adapters/safety/sample/sample.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/adapters/inference/sample/sample.py b/llama_stack/providers/adapters/inference/sample/sample.py index cfe773036..7d4e4a837 100644 --- a/llama_stack/providers/adapters/inference/sample/sample.py +++ b/llama_stack/providers/adapters/inference/sample/sample.py @@ -9,10 +9,17 @@ from .config import SampleConfig from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.distribution.datatypes import RoutableProvider -class SampleInferenceImpl(Inference): + +class SampleInferenceImpl(Inference, RoutableProvider): def __init__(self, config: SampleConfig): self.config = config + async def validate_routing_keys(self, routing_keys: list[str]) -> None: + # these are the model names the Llama Stack will use to route requests to this provider + # perform validation here if necessary + pass + async def initialize(self): pass diff --git a/llama_stack/providers/adapters/safety/sample/sample.py b/llama_stack/providers/adapters/safety/sample/sample.py index 4631bde26..a71f5143f 100644 --- a/llama_stack/providers/adapters/safety/sample/sample.py +++ b/llama_stack/providers/adapters/safety/sample/sample.py @@ -9,10 +9,17 @@ from .config import SampleConfig from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.distribution.datatypes import RoutableProvider -class SampleSafetyImpl(Safety): + +class SampleSafetyImpl(Safety, RoutableProvider): def __init__(self, config: SampleConfig): self.config = config + async def validate_routing_keys(self, routing_keys: list[str]) -> None: + # these are the safety shields the Llama Stack will use to route requests to this provider + # perform validation here if necessary + pass + async def initialize(self): pass