Another round of simplification and clarity for the models/shields/memory_banks APIs

Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions

View file

@@ -121,3 +121,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
"stream": request.stream,
**options,
}
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -15,6 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
@@ -35,7 +36,7 @@ OLLAMA_SUPPORTED_MODELS = {
}
class OllamaInferenceAdapter(Inference):
class OllamaInferenceAdapter(Inference, Models):
def __init__(self, url: str) -> None:
self.url = url
self.formatter = ChatFormat(Tokenizer.get_instance())

View file

@@ -6,14 +6,18 @@
import logging
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_models.sku_list import all_registered_models
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
@@ -30,26 +34,47 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
logger = logging.getLogger(__name__)
class _HfAdapter(Inference):
class _HfAdapter(Inference, ModelsProtocolPrivate):
client: AsyncInferenceClient
max_tokens: int
model_id: str
def __init__(self) -> None:
self.formatter = ChatFormat(Tokenizer.get_instance())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor()
for model in all_registered_models()
if model.huggingface_repo
}
async def register_model(self, model: ModelDef) -> None:
resolved_model = resolve_model(model.identifier)
if resolved_model is None:
raise ValueError(f"Unknown model: {model.identifier}")
raise ValueError("Model registration is not supported for HuggingFace models")
if not resolved_model.huggingface_repo:
raise ValueError(
f"Model {model.identifier} does not have a HuggingFace repo"
async def list_models(self) -> List[ModelDef]:
repo = self.model_id
identifier = self.huggingface_repo_to_llama_model_id[repo]
return [
ModelDef(
identifier=identifier,
llama_model=identifier,
metadata={
"huggingface_repo": repo,
},
)
]
if self.model_id != resolved_model.huggingface_repo:
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
async def get_model(self, identifier: str) -> Optional[ModelDef]:
model = self.huggingface_repo_to_llama_model_id.get(self.model_id)
if model != identifier:
return None
return ModelDef(
identifier=model,
llama_model=model,
metadata={
"huggingface_repo": self.model_id,
},
)
async def shutdown(self) -> None:
pass
@@ -145,6 +170,13 @@ class _HfAdapter(Inference):
**options,
)
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
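Reviewer note: the reworked _HfAdapter answers list_models/get_model by mapping the serving endpoint's model_id (a HuggingFace repo) back to a Llama Stack identifier. Below is a minimal standalone sketch of that mapping, using only imports already present in this diff; the repo name in the final lookup is a hypothetical example, not a claim about the registry contents.

from llama_models.sku_list import all_registered_models

# The same mapping _HfAdapter.__init__ builds: HuggingFace repo -> stack identifier.
repo_to_identifier = {
    model.huggingface_repo: model.descriptor()
    for model in all_registered_models()
    if model.huggingface_repo
}

# Hypothetical lookup: whatever repo the TGI endpoint reports as its model_id
# resolves to the identifier that list_models()/get_model() will return.
print(repo_to_identifier.get("meta-llama/Llama-3.1-8B-Instruct"))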

View file

@@ -134,3 +134,10 @@ class TogetherInferenceAdapter(
"stream": request.stream,
**get_sampling_options(request),
}
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -10,6 +10,11 @@ from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.memory_banks import MemoryBankDef
from llama_stack.apis.models import ModelDef
from llama_stack.apis.shields import ShieldDef
@json_schema_type
class Api(Enum):
@@ -28,6 +33,30 @@ class Api(Enum):
inspect = "inspect"
class ModelsProtocolPrivate(Protocol):
async def list_models(self) -> List[ModelDef]: ...
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
async def register_model(self, model: ModelDef) -> None: ...
class ShieldsProtocolPrivate(Protocol):
async def list_shields(self) -> List[ShieldDef]: ...
async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ...
async def register_shield(self, shield: ShieldDef) -> None: ...
class MemoryBanksProtocolPrivate(Protocol):
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
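Reviewer note: a minimal sketch (not part of this commit) of what an inference provider now has to implement to satisfy ModelsProtocolPrivate; StaticModelProvider and its single hard-coded model are hypothetical, but the shape mirrors the meta-reference changes below.

from typing import List, Optional

from llama_stack.apis.models import ModelDef
from llama_stack.providers.datatypes import ModelsProtocolPrivate


class StaticModelProvider(ModelsProtocolPrivate):
    # Hypothetical provider that serves exactly one pre-configured model.
    def __init__(self, model_id: str) -> None:
        self.model_id = model_id

    async def list_models(self) -> List[ModelDef]:
        return [ModelDef(identifier=self.model_id, llama_model=self.model_id)]

    async def get_model(self, identifier: str) -> Optional[ModelDef]:
        if identifier != self.model_id:
            return None
        return ModelDef(identifier=self.model_id, llama_model=self.model_id)

    async def register_model(self, model: ModelDef) -> None:
        # Mirrors the providers in this commit that reject dynamic registration.
        raise ValueError("Dynamic model registration is not supported")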

View file

@@ -12,6 +12,7 @@ from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
@@ -24,7 +25,7 @@ from .model_parallel import LlamaModelParallelGenerator
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference):
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
def __init__(self, config: MetaReferenceImplConfig) -> None:
self.config = config
model = resolve_model(config.model)
@@ -39,14 +40,38 @@ class MetaReferenceInferenceImpl(Inference):
self.generator.start()
async def register_model(self, model: ModelDef) -> None:
if model.identifier != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
raise ValueError("Dynamic model registration is not supported")
async def list_models(self) -> List[ModelDef]:
return [
ModelDef(
identifier=self.model.descriptor(),
llama_model=self.model.descriptor(),
)
]
async def get_model(self, identifier: str) -> Optional[ModelDef]:
if self.model.descriptor() != identifier:
return None
return ModelDef(
identifier=self.model.descriptor(),
llama_model=self.model.descriptor(),
)
async def shutdown(self) -> None:
self.generator.stop()
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
def chat_completion(
self,
model: str,
@@ -255,3 +280,10 @@ class MetaReferenceInferenceImpl(Inference):
stop_reason=stop_reason,
)
)
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -15,6 +15,8 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
@@ -61,7 +63,7 @@ class FaissIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory):
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.cache = {}
@@ -83,6 +85,16 @@ class FaissMemoryImpl(Memory):
)
self.cache[memory_bank.identifier] = index
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
banks = await self.list_memory_banks()
for bank in banks:
if bank.identifier == identifier:
return bank
return None
async def list_memory_banks(self) -> List[MemoryBankDef]:
return [i.bank for i in self.cache.values()]
async def insert_documents(
self,
bank_id: str,

View file

@@ -28,7 +28,7 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/memory/test_inference.py \
# pytest -s llama_stack/providers/tests/inference/test_inference.py \
# --tb=short --disable-warnings
# ```
@@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
scope="session",
params=[
{"model": Llama_8B},
{"model": Llama_3B},
# {"model": Llama_3B},
],
ids=lambda d: d["model"],
)
@@ -64,16 +64,11 @@ async def inference_settings(request):
model = request.param["model"]
impls = await resolve_impls_for_test(
Api.inference,
models=[
ModelDef(
identifier=model,
llama_model=model,
)
],
)
return {
"impl": impls[Api.inference],
"models_impl": impls[Api.models],
"common_params": {
"model": model,
"tool_choice": ToolChoice.auto,
@@ -108,6 +103,25 @@ def sample_tool_definition():
)
@pytest.mark.asyncio
async def test_model_list(inference_settings):
params = inference_settings["common_params"]
models_impl = inference_settings["models_impl"]
response = await models_impl.list_models()
assert isinstance(response, list)
assert len(response) >= 1
assert all(isinstance(model, ModelDefWithProvider) for model in response)
model_def = None
for model in response:
if model.identifier == params["model"]:
model_def = model
break
assert model_def is not None
assert model_def.identifier == params["model"]
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]

View file

@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
@@ -30,12 +31,14 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
@pytest_asyncio.fixture(scope="session")
async def memory_impl():
async def memory_settings():
impls = await resolve_impls_for_test(
Api.memory,
memory_banks=[],
)
return impls[Api.memory]
return {
"memory_impl": impls[Api.memory],
"memory_banks_impl": impls[Api.memory_banks],
}
@pytest.fixture
@@ -64,23 +67,35 @@ def sample_documents():
]
async def register_memory_bank(memory_impl: Memory):
async def register_memory_bank(banks_impl: MemoryBanks):
bank = VectorMemoryBankDef(
identifier="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
provider_id=os.environ["PROVIDER_ID"],
)
await memory_impl.register_memory_bank(bank)
await banks_impl.register_memory_bank(bank)
@pytest.mark.asyncio
async def test_query_documents(memory_impl, sample_documents):
async def test_banks_list(memory_settings):
banks_impl = memory_settings["memory_banks_impl"]
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 0
@pytest.mark.asyncio
async def test_query_documents(memory_settings, sample_documents):
memory_impl = memory_settings["memory_impl"]
banks_impl = memory_settings["memory_banks_impl"]
with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents)
await register_memory_bank(memory_impl)
await register_memory_bank(banks_impl)
await memory_impl.insert_documents("test_bank", sample_documents)
query1 = "programming language"

View file

@@ -18,9 +18,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
async def resolve_impls_for_test(
api: Api,
models: List[ModelDef] = None,
memory_banks: List[MemoryBankDef] = None,
shields: List[ShieldDef] = None,
):
if "PROVIDER_CONFIG" not in os.environ:
raise ValueError(
@@ -47,45 +44,11 @@ async def resolve_impls_for_test(
provider_id = provider["provider_id"]
print(f"No provider ID specified, picking first `{provider_id}`")
models = models or []
shields = shields or []
memory_banks = memory_banks or []
models = [
ModelDef(
**{
**m.dict(),
"provider_id": provider_id,
}
)
for m in models
]
shields = [
ShieldDef(
**{
**s.dict(),
"provider_id": provider_id,
}
)
for s in shields
]
memory_banks = [
MemoryBankDef(
**{
**m.dict(),
"provider_id": provider_id,
}
)
for m in memory_banks
]
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=[api],
providers={api.value: [Provider(**provider)]},
models=models,
memory_banks=memory_banks,
shields=shields,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls_with_routing(run_config)
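Reviewer note: with the models/shields/memory_banks arguments removed, test fixtures now call the resolver with just the API under test and reach registered objects through the routed registry impls. A sketch of the resulting fixture pattern, matching the updated inference test above; the fixture name is illustrative, and PROVIDER_CONFIG (plus PROVIDER_ID where a test registers objects) must still be set in the environment.

import pytest_asyncio

from llama_stack.providers.datatypes import Api
from llama_stack.providers.tests.resolver import resolve_impls_for_test


@pytest_asyncio.fixture(scope="session")
async def settings():
    # No pre-registered object lists any more -- just the API under test.
    impls = await resolve_impls_for_test(Api.inference)
    return {
        "impl": impls[Api.inference],      # chat_completion, embeddings, ...
        "models_impl": impls[Api.models],  # registry API exercised by test_model_list
    }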

View file

@@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Dict, List, Optional
from llama_models.sku_list import resolve_model
from llama_stack.apis.models import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
class ModelRegistryHelper:
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
self.stack_to_provider_models_map = stack_to_provider_models_map
@@ -33,3 +33,15 @@ class ModelRegistryHelper:
raise ValueError(
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
)
async def list_models(self) -> List[ModelDef]:
models = []
for llama_model, provider_model in self.stack_to_provider_models_map.items():
models.append(ModelDef(identifier=llama_model, llama_model=llama_model))
return models
async def get_model(self, identifier: str) -> Optional[ModelDef]:
if identifier not in self.stack_to_provider_models_map:
return None
return ModelDef(identifier=identifier, llama_model=identifier)
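Reviewer note: a usage sketch of the new default list_models/get_model behaviour on ModelRegistryHelper; the import path and the mapping values are assumptions for illustration, not taken from this diff.

import asyncio

# Import path assumed from the providers/utils/inference layout used elsewhere in this diff.
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper


async def main() -> None:
    # Placeholder mapping: stack identifier -> provider-side model name.
    helper = ModelRegistryHelper(
        stack_to_provider_models_map={"Llama3.1-8B-Instruct": "provider/llama-3.1-8b"}
    )
    print(await helper.list_models())                      # one ModelDef per mapping key
    print(await helper.get_model("Llama3.1-8B-Instruct"))  # ModelDef for a known identifier
    print(await helper.get_model("not-a-model"))           # None for anything unmapped


asyncio.run(main())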