Add a RoutableProvider protocol, support for multiple routing keys (#163)

* Update configure.py to use multiple routing keys for safety
* Refactor distribution/datatypes into a providers/datatypes
* Cleanup
This commit is contained in:
Ashwin Bharambe 2024-09-30 17:30:21 -07:00 committed by GitHub
parent 73decb3781
commit eb2d8a31a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 600 additions and 577 deletions

View file

@ -12,20 +12,21 @@ from botocore.config import Config
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
# mapping of Model SKUs to Bedrock model IDs
BEDROCK_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
"Meta-Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
"Meta-Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
}
class BedrockInferenceAdapter(Inference):
class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
@staticmethod
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
@ -68,6 +69,9 @@ class BedrockInferenceAdapter(Inference):
return boto3_session.client("bedrock-runtime", config=boto3_config)
def __init__(self, config: BedrockConfig) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
)
self._config = config
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
@ -94,22 +98,6 @@ class BedrockInferenceAdapter(Inference):
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
@staticmethod
def resolve_bedrock_model(model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True)
in BEDROCK_SUPPORTED_MODELS
), (
f"Unsupported model: {model_name}, use one of the supported models: "
f"{','.join(BEDROCK_SUPPORTED_MODELS.keys())}"
)
return BEDROCK_SUPPORTED_MODELS.get(
model.descriptor(shorten_default_variant=True)
)
@staticmethod
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
if bedrock_stop_reason == "max_tokens":
@ -350,7 +338,7 @@ class BedrockInferenceAdapter(Inference):
) -> (
AsyncGenerator
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
bedrock_model = BedrockInferenceAdapter.resolve_bedrock_model(model)
bedrock_model = self.map_to_provider_model(model)
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
sampling_params
)

View file

@ -12,7 +12,8 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
@ -21,6 +22,7 @@ from llama_stack.providers.utils.inference.augment_messages import (
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
@ -28,8 +30,11 @@ FIREWORKS_SUPPORTED_MODELS = {
}
class FireworksInferenceAdapter(Inference):
class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
def __init__(self, config: FireworksImplConfig) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@ -65,18 +70,6 @@ class FireworksInferenceAdapter(Inference):
return fireworks_messages
def resolve_fireworks_model(self, model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True)
in FIREWORKS_SUPPORTED_MODELS
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(FIREWORKS_SUPPORTED_MODELS.keys())}"
return FIREWORKS_SUPPORTED_MODELS.get(
model.descriptor(shorten_default_variant=True)
)
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
@ -112,7 +105,7 @@ class FireworksInferenceAdapter(Inference):
# accumulate sampling params and other options to pass to fireworks
options = self.get_fireworks_chat_options(request)
fireworks_model = self.resolve_fireworks_model(request.model)
fireworks_model = self.map_to_provider_model(request.model)
if not request.stream:
r = await self.client.chat.completions.acreate(

View file

@ -11,7 +11,6 @@ import httpx
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from ollama import AsyncClient
@ -19,6 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
@ -29,8 +29,11 @@ OLLAMA_SUPPORTED_SKUS = {
}
class OllamaInferenceAdapter(Inference):
class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
def __init__(self, url: str) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
)
self.url = url
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@ -72,15 +75,6 @@ class OllamaInferenceAdapter(Inference):
return ollama_messages
def resolve_ollama_model(self, model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}"
return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True))
def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
@ -120,7 +114,7 @@ class OllamaInferenceAdapter(Inference):
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.resolve_ollama_model(request.model)
ollama_model = self.map_to_provider_model(request.model)
res = await self.client.ps()
need_model_pull = True

View file

@ -13,6 +13,8 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
@ -23,7 +25,7 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
logger = logging.getLogger(__name__)
class _HfAdapter(Inference):
class _HfAdapter(Inference, RoutableProvider):
client: AsyncInferenceClient
max_tokens: int
model_id: str
@ -32,6 +34,11 @@ class _HfAdapter(Inference):
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
    """Accept every routing key; this adapter performs no validation.

    The routing keys are the model names the Llama Stack will use to route
    requests to this provider. Add validation here if it becomes necessary.
    """
    return None
async def shutdown(self) -> None:
pass

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@ -14,7 +16,7 @@ class TogetherImplConfig(BaseModel):
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
api_key: str = Field(
default="",
api_key: Optional[str] = Field(
default=None,
description="The Together AI API Key",
)

View file

@ -10,7 +10,6 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from together import Together
@ -19,9 +18,11 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from .config import TogetherImplConfig
TOGETHER_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
@ -32,8 +33,13 @@ TOGETHER_SUPPORTED_MODELS = {
}
class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
class TogetherInferenceAdapter(
Inference, NeedsRequestProviderData, RoutableProviderForModels
):
def __init__(self, config: TogetherImplConfig) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@ -69,18 +75,6 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
return together_messages
def resolve_together_model(self, model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True)
in TOGETHER_SUPPORTED_MODELS
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(TOGETHER_SUPPORTED_MODELS.keys())}"
return TOGETHER_SUPPORTED_MODELS.get(
model.descriptor(shorten_default_variant=True)
)
def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
@ -103,12 +97,15 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
) -> AsyncGenerator:
together_api_key = None
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
client = Together(api_key=together_api_key)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
@ -125,7 +122,7 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
# accumulate sampling params and other options to pass to together
options = self.get_together_chat_options(request)
together_model = self.resolve_together_model(request.model)
together_model = self.map_to_provider_model(request.model)
messages = augment_messages_for_tools(request)
if not request.stream:
@ -171,17 +168,10 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
stream=True,
**options,
):
if chunk.choices[0].finish_reason:
if (
stop_reason is None and chunk.choices[0].finish_reason == "stop"
) or (
stop_reason is None and chunk.choices[0].finish_reason == "eos"
):
if finish_reason := chunk.choices[0].finish_reason:
if stop_reason is None and finish_reason in ["stop", "eos"]:
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None
and chunk.choices[0].finish_reason == "length"
):
elif stop_reason is None and finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break

View file

@ -13,7 +13,7 @@ import chromadb
from numpy.typing import NDArray
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
@ -65,7 +65,7 @@ class ChromaIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class ChromaMemoryAdapter(Memory):
class ChromaMemoryAdapter(Memory, RoutableProvider):
def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
@ -93,6 +93,10 @@ class ChromaMemoryAdapter(Memory):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
    """Announce the memory bank routing keys handed to this adapter.

    No validation is performed; the keys are only echoed to stdout.
    """
    announcement = f"[chroma] Registering memory bank routing keys: {routing_keys}"
    print(announcement)
async def create_memory_bank(
self,
name: str,

View file

@ -5,16 +5,17 @@
# the root directory of this source tree.
import uuid
from typing import List, Tuple
import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from pydantic import BaseModel
from llama_stack.apis.memory import * # noqa: F403
from pydantic import BaseModel
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
@ -118,7 +119,7 @@ class PGVectorIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class PGVectorMemoryAdapter(Memory):
class PGVectorMemoryAdapter(Memory, RoutableProvider):
def __init__(self, config: PGVectorConfig) -> None:
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
self.config = config
@ -160,6 +161,10 @@ class PGVectorMemoryAdapter(Memory):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
    """Announce the memory bank routing keys handed to this adapter.

    No validation is performed; the keys are only echoed to stdout.
    """
    announcement = f"[pgvector] Registering memory bank routing keys: {routing_keys}"
    print(announcement)
async def create_memory_bank(
self,
name: str,

View file

@ -4,47 +4,58 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import traceback
from typing import Any, Dict, List
from .config import BedrockSafetyConfig
import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
import json
import logging
from llama_stack.distribution.datatypes import RoutableProvider
import boto3
from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
class BedrockSafetyAdapter(Safety):
SUPPORTED_SHIELD_TYPES = [
"bedrock_guardrail",
]
class BedrockSafetyAdapter(Safety, RoutableProvider):
def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config
async def initialize(self) -> None:
if not self.config.aws_profile:
raise RuntimeError(
f"Missing boto_client aws_profile in model info::{self.config}"
)
try:
print(f"initializing with profile --- > {self.config}::")
self.boto_client_profile = self.config.aws_profile
print(f"initializing with profile --- > {self.config}")
self.boto_client = boto3.Session(
profile_name=self.boto_client_profile
profile_name=self.config.aws_profile
).client("bedrock-runtime")
except Exception as e:
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
    """Reject any routing key that is not a supported shield type.

    Raises:
        ValueError: for the first key not found in SUPPORTED_SHIELD_TYPES.
    """
    for candidate in routing_keys:
        if candidate in SUPPORTED_SHIELD_TYPES:
            continue
        raise ValueError(f"Unknown safety shield type: {candidate}")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{

View file

@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.sku_list import resolve_model
from together import Together
from llama_models.llama3.api.datatypes import * # noqa: F403
@ -13,53 +12,52 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig
SAFETY_SHIELD_TYPES = {
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}
def shield_type_to_model_name(shield_type: str) -> str:
if shield_type == "llama_guard":
shield_type = "Llama-Guard-3-8B"
model = resolve_model(shield_type)
if (
model is None
or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES
or model.model_family is not ModelFamily.safety
):
raise ValueError(
f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}"
)
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
    """Reject any routing key that is not a known safety shield type.

    Raises:
        ValueError: for the first key not found in SAFETY_SHIELD_TYPES.
    """
    for candidate in routing_keys:
        if candidate in SAFETY_SHIELD_TYPES:
            continue
        raise ValueError(f"Unknown safety shield type: {candidate}")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SAFETY_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
together_api_key = None
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
model_name = shield_type_to_model_name(shield_type)
model_name = SAFETY_SHIELD_TYPES[shield_type]
# messages can have role assistant or user
api_messages = []