Another round of simplification and clarity for the models/shields/memory_banks APIs

Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions


@@ -6,14 +6,18 @@
 import logging
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional
 
 from huggingface_hub import AsyncInferenceClient, HfApi
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
+from llama_models.sku_list import all_registered_models
 
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.models import *  # noqa: F403
+
+from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
+
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -30,26 +34,47 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
 logger = logging.getLogger(__name__)
 
 
-class _HfAdapter(Inference):
+class _HfAdapter(Inference, ModelsProtocolPrivate):
     client: AsyncInferenceClient
     max_tokens: int
     model_id: str
 
     def __init__(self) -> None:
         self.formatter = ChatFormat(Tokenizer.get_instance())
+        self.huggingface_repo_to_llama_model_id = {
+            model.huggingface_repo: model.descriptor()
+            for model in all_registered_models()
+            if model.huggingface_repo
+        }
 
     async def register_model(self, model: ModelDef) -> None:
-        resolved_model = resolve_model(model.identifier)
-        if resolved_model is None:
-            raise ValueError(f"Unknown model: {model.identifier}")
-
-        if not resolved_model.huggingface_repo:
-            raise ValueError(
-                f"Model {model.identifier} does not have a HuggingFace repo"
-            )
-
-        if self.model_id != resolved_model.huggingface_repo:
-            raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
+        raise ValueError("Model registration is not supported for HuggingFace models")
+
+    async def list_models(self) -> List[ModelDef]:
+        repo = self.model_id
+        identifier = self.huggingface_repo_to_llama_model_id[repo]
+        return [
+            ModelDef(
+                identifier=identifier,
+                llama_model=identifier,
+                metadata={
+                    "huggingface_repo": repo,
+                },
+            )
+        ]
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        model = self.huggingface_repo_to_llama_model_id.get(self.model_id)
+        if model != identifier:
+            return None
+
+        return ModelDef(
+            identifier=model,
+            llama_model=model,
+            metadata={
+                "huggingface_repo": self.model_id,
+            },
+        )
 
     async def shutdown(self) -> None:
         pass
@@ -145,6 +170,13 @@ class _HfAdapter(Inference):
             **options,
         )
 
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
+
 
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
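
Net effect of this diff: the HF/TGI adapter no longer resolves or validates models at registration time; it rejects registration outright and instead advertises the single model the endpoint is serving, via list_models() and get_model(). A minimal sketch of the resulting behavior, assuming an adapter whose initialize() has already populated model_id (the sketch itself is illustrative and not part of the commit):

    import asyncio

    async def inspect_adapter(adapter) -> None:
        # list_models() derives the one served model from the
        # huggingface_repo -> llama descriptor map built in __init__.
        models = await adapter.list_models()
        for m in models:
            print(m.identifier, m.metadata["huggingface_repo"])

        # get_model() returns a ModelDef only when the requested
        # identifier matches the model the endpoint actually serves.
        found = await adapter.get_model(models[0].identifier)
        assert found is not None

        # register_model() now always raises, since the served model
        # is fixed by the TGI deployment, not by the registry.
        try:
            await adapter.register_model(models[0])
        except ValueError as e:
            print(e)  # "Model registration is not supported for HuggingFace models"

    # Usage (assumes `adapter` was created and initialized elsewhere):
    # asyncio.run(inspect_adapter(adapter))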