Introduce model_store, shield_store, memory_bank_store

Ashwin Bharambe, 2024-10-06 16:29:33 -07:00 (committed by Ashwin Bharambe)
parent e45a417543
commit 91e0063593
19 changed files with 172 additions and 297 deletions

@@ -6,39 +6,40 @@
from typing import AsyncGenerator
from openai import OpenAI
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .config import DatabricksImplConfig
DATABRICKS_SUPPORTED_MODELS = {
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
}
class DatabricksInferenceAdapter(Inference):
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> OpenAI:
return OpenAI(
base_url=self.config.url,
api_key=self.config.api_token
)
return OpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def initialize(self) -> None:
return
@@ -65,18 +66,6 @@ class DatabricksInferenceAdapter(Inference):
return databricks_messages
def resolve_databricks_model(self, model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True)
in DATABRICKS_SUPPORTED_MODELS
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(DATABRICKS_SUPPORTED_MODELS.keys())}"
return DATABRICKS_SUPPORTED_MODELS.get(
model.descriptor(shorten_default_variant=True)
)
def get_databricks_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
@@ -110,10 +99,9 @@ class DatabricksInferenceAdapter(Inference):
messages = augment_messages_for_tools(request)
options = self.get_databricks_chat_options(request)
databricks_model = self.resolve_databricks_model(request.model)
databricks_model = self.map_to_provider_model(request.model)
if not request.stream:
r = self.client.chat.completions.create(
model=databricks_model,
messages=self._messages_to_databricks_messages(messages),
@@ -154,10 +142,7 @@ class DatabricksInferenceAdapter(Inference):
**options,
):
if chunk.choices[0].finish_reason:
if (
stop_reason is None
and chunk.choices[0].finish_reason == "stop"
):
if stop_reason is None and chunk.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None
@@ -254,4 +239,4 @@ class DatabricksInferenceAdapter(Inference):
delta="",
stop_reason=stop_reason,
)
)
)
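
Note: the Databricks adapter no longer carries its own resolve_databricks_model helper; it mixes in ModelRegistryHelper and calls map_to_provider_model instead. Below is a minimal sketch of the mapping behaviour implied by its usage in this diff; the real helper in llama_stack.providers.utils.inference.model_registry may differ in details.

# Sketch only: reconstructs map_to_provider_model from how it is used above.
from llama_models.sku_list import resolve_model


class ModelRegistryHelperSketch:
    def __init__(self, stack_to_provider_models_map: dict) -> None:
        self.stack_to_provider_models_map = stack_to_provider_models_map

    def map_to_provider_model(self, identifier: str) -> str:
        # Resolve the Llama Stack identifier to a known SKU, then look up the
        # provider-specific name (e.g. "databricks-meta-llama-3-1-70b-instruct").
        model = resolve_model(identifier)
        if model is None:
            raise ValueError(f"Unknown model: {identifier}")
        descriptor = model.descriptor(shorten_default_variant=True)
        if descriptor not in self.stack_to_provider_models_map:
            raise ValueError(
                f"Model {identifier} is not supported; supported models: "
                f"{', '.join(self.stack_to_provider_models_map.keys())}"
            )
        return self.stack_to_provider_models_map[descriptor]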

@@ -21,7 +21,7 @@ from llama_stack.providers.utils.inference.augment_messages import (
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
OLLAMA_SUPPORTED_SKUS = {
OLLAMA_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
@@ -33,7 +33,7 @@ OLLAMA_SUPPORTED_SKUS = {
class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, url: str) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS
)
self.url = url
tokenizer = Tokenizer.get_instance()
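
Note: the Ollama change is a rename for consistency (SKUS becomes MODELS); the mapping itself is unchanged, so a stack identifier such as "Llama3.1-8B-Instruct" should still map to the Ollama tag "llama3.1:8b-instruct-fp16". A usage sketch, assuming the helper behaves as sketched earlier and an Ollama server at its default local URL:

# OllamaInferenceAdapter is the class shown in the hunk above; adapters are
# normally constructed by the provider registry rather than by hand.
adapter = OllamaInferenceAdapter(url="http://localhost:11434")
assert adapter.map_to_provider_model("Llama3.1-8B-Instruct") == "llama3.1:8b-instruct-fp16"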

@@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleInferenceImpl(Inference, RoutableProvider):
class SampleInferenceImpl(Inference):
def __init__(self, config: SampleConfig):
self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
async def register_model(self, model: ModelDef) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
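
Note: the sample provider shows the new registration surface in its smallest form: instead of validating routing keys up front, a provider now receives a ModelDef when a model is registered against it and may accept or reject it. A hedged sketch of a provider that restricts registration to a fixed set (the identifiers below are illustrative):

# Illustrative only; ModelDef and Inference come from llama_stack.apis.inference
# as imported at the top of this file.
_SUPPORTED_IDENTIFIERS = {"Llama3.1-8B-Instruct", "Llama3.2-1B-Instruct"}


class RestrictiveSampleImpl(SampleInferenceImpl):
    async def register_model(self, model: ModelDef) -> None:
        # Reject identifiers this provider does not know how to serve.
        if model.identifier not in _SUPPORTED_IDENTIFIERS:
            raise ValueError(f"Unsupported model: {model.identifier}")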

@@ -12,6 +12,7 @@ from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
@@ -32,16 +33,18 @@ class _HfAdapter(Inference):
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
# TODO: make this work properly by checking this against the model_id being
# served by the remote endpoint
async def register_model(self, model: ModelDef) -> None:
pass
resolved_model = resolve_model(model.identifier)
if resolved_model is None:
raise ValueError(f"Unknown model: {model.identifier}")
async def list_models(self) -> List[ModelDef]:
return []
if not resolved_model.huggingface_repo:
raise ValueError(
f"Model {model.identifier} does not have a HuggingFace repo"
)
async def get_model(self, identifier: str) -> Optional[ModelDef]:
return None
if self.model_id != resolved_model.huggingface_repo:
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
async def shutdown(self) -> None:
pass
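
Note: the TGI adapter replaces its list_models/get_model stubs with a real register_model check: the registered identifier must resolve to a known SKU whose HuggingFace repo matches the model_id the remote endpoint is serving. The same check, restated as a standalone function for readability (a sketch; served_model_id stands in for whatever the endpoint reports):

from llama_models.sku_list import resolve_model


def check_served_model(identifier: str, served_model_id: str) -> None:
    # Mirrors the validation register_model now performs in the hunk above.
    resolved = resolve_model(identifier)
    if resolved is None:
        raise ValueError(f"Unknown model: {identifier}")
    if not resolved.huggingface_repo:
        raise ValueError(f"Model {identifier} does not have a HuggingFace repo")
    if served_model_id != resolved.huggingface_repo:
        raise ValueError(f"Model mismatch: {identifier} != {served_model_id}")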