mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
Another round of simplification and clarity for models/shields/memory_banks stuff
This commit is contained in:
parent
73a0a34e39
commit
b55034c0de
27 changed files with 454 additions and 444 deletions
|
|
@ -121,3 +121,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
|||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
|
|
@ -35,7 +36,7 @@ OLLAMA_SUPPORTED_MODELS = {
|
|||
}
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference):
|
||||
class OllamaInferenceAdapter(Inference, Models):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.url = url
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
|
|
|||
|
|
@ -6,14 +6,18 @@
|
|||
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
|
|
@ -30,26 +34,47 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _HfAdapter(Inference):
|
||||
class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||
client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
self.huggingface_repo_to_llama_model_id = {
|
||||
model.huggingface_repo: model.descriptor()
|
||||
for model in all_registered_models()
|
||||
if model.huggingface_repo
|
||||
}
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
resolved_model = resolve_model(model.identifier)
|
||||
if resolved_model is None:
|
||||
raise ValueError(f"Unknown model: {model.identifier}")
|
||||
raise ValueError("Model registration is not supported for HuggingFace models")
|
||||
|
||||
if not resolved_model.huggingface_repo:
|
||||
raise ValueError(
|
||||
f"Model {model.identifier} does not have a HuggingFace repo"
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
repo = self.model_id
|
||||
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||
return [
|
||||
ModelDef(
|
||||
identifier=identifier,
|
||||
llama_model=identifier,
|
||||
metadata={
|
||||
"huggingface_repo": repo,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
if self.model_id != resolved_model.huggingface_repo:
|
||||
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
model = self.huggingface_repo_to_llama_model_id.get(self.model_id)
|
||||
if model != identifier:
|
||||
return None
|
||||
|
||||
return ModelDef(
|
||||
identifier=model,
|
||||
llama_model=model,
|
||||
metadata={
|
||||
"huggingface_repo": self.model_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
|
@ -145,6 +170,13 @@ class _HfAdapter(Inference):
|
|||
**options,
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
|
|
|
|||
|
|
@ -134,3 +134,10 @@ class TogetherInferenceAdapter(
|
|||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -10,6 +10,11 @@ from typing import Any, List, Optional, Protocol
|
|||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.memory_banks import MemoryBankDef
|
||||
|
||||
from llama_stack.apis.models import ModelDef
|
||||
from llama_stack.apis.shields import ShieldDef
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Api(Enum):
|
||||
|
|
@ -28,6 +33,30 @@ class Api(Enum):
|
|||
inspect = "inspect"
|
||||
|
||||
|
||||
class ModelsProtocolPrivate(Protocol):
|
||||
async def list_models(self) -> List[ModelDef]: ...
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None: ...
|
||||
|
||||
|
||||
class ShieldsProtocolPrivate(Protocol):
|
||||
async def list_shields(self) -> List[ShieldDef]: ...
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ...
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None: ...
|
||||
|
||||
|
||||
class MemoryBanksProtocolPrivate(Protocol):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
|
||||
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
|
||||
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderSpec(BaseModel):
|
||||
api: Api
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from llama_models.sku_list import resolve_model
|
|||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
|
|
@ -24,7 +25,7 @@ from .model_parallel import LlamaModelParallelGenerator
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(Inference):
|
||||
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
|
|
@ -39,14 +40,38 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
self.generator.start()
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
if model.identifier != self.model.descriptor():
|
||||
raise RuntimeError(
|
||||
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
|
||||
raise ValueError("Dynamic model registration is not supported")
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return [
|
||||
ModelDef(
|
||||
identifier=self.model.descriptor(),
|
||||
llama_model=self.model.descriptor(),
|
||||
)
|
||||
]
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
if self.model.descriptor() != identifier:
|
||||
return None
|
||||
|
||||
return ModelDef(
|
||||
identifier=self.model.descriptor(),
|
||||
llama_model=self.model.descriptor(),
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.generator.stop()
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -255,3 +280,10 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ from numpy.typing import NDArray
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
BankWithIndex,
|
||||
|
|
@ -61,7 +63,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class FaissMemoryImpl(Memory):
|
||||
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
|
|
@ -83,6 +85,16 @@ class FaissMemoryImpl(Memory):
|
|||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
|
||||
banks = await self.list_memory_banks()
|
||||
for bank in banks:
|
||||
if bank.identifier == identifier:
|
||||
return bank
|
||||
return None
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
|||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/memory/test_inference.py \
|
||||
# pytest -s llama_stack/providers/tests/inference/test_inference.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
|
||||
|
|
@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
|
|||
scope="session",
|
||||
params=[
|
||||
{"model": Llama_8B},
|
||||
{"model": Llama_3B},
|
||||
# {"model": Llama_3B},
|
||||
],
|
||||
ids=lambda d: d["model"],
|
||||
)
|
||||
|
|
@ -64,16 +64,11 @@ async def inference_settings(request):
|
|||
model = request.param["model"]
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.inference,
|
||||
models=[
|
||||
ModelDef(
|
||||
identifier=model,
|
||||
llama_model=model,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return {
|
||||
"impl": impls[Api.inference],
|
||||
"models_impl": impls[Api.models],
|
||||
"common_params": {
|
||||
"model": model,
|
||||
"tool_choice": ToolChoice.auto,
|
||||
|
|
@ -108,6 +103,25 @@ def sample_tool_definition():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_list(inference_settings):
|
||||
params = inference_settings["common_params"]
|
||||
models_impl = inference_settings["models_impl"]
|
||||
response = await models_impl.list_models()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) >= 1
|
||||
assert all(isinstance(model, ModelDefWithProvider) for model in response)
|
||||
|
||||
model_def = None
|
||||
for model in response:
|
||||
if model.identifier == params["model"]:
|
||||
model_def = model
|
||||
break
|
||||
|
||||
assert model_def is not None
|
||||
assert model_def.identifier == params["model"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
|
@ -30,12 +31,14 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
|||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def memory_impl():
|
||||
async def memory_settings():
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.memory,
|
||||
memory_banks=[],
|
||||
)
|
||||
return impls[Api.memory]
|
||||
return {
|
||||
"memory_impl": impls[Api.memory],
|
||||
"memory_banks_impl": impls[Api.memory_banks],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -64,23 +67,35 @@ def sample_documents():
|
|||
]
|
||||
|
||||
|
||||
async def register_memory_bank(memory_impl: Memory):
|
||||
async def register_memory_bank(banks_impl: MemoryBanks):
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
provider_id=os.environ["PROVIDER_ID"],
|
||||
)
|
||||
|
||||
await memory_impl.register_memory_bank(bank)
|
||||
await banks_impl.register_memory_bank(bank)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(memory_impl, sample_documents):
|
||||
async def test_banks_list(memory_settings):
|
||||
banks_impl = memory_settings["memory_banks_impl"]
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(memory_settings, sample_documents):
|
||||
memory_impl = memory_settings["memory_impl"]
|
||||
banks_impl = memory_settings["memory_banks_impl"]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
await register_memory_bank(memory_impl)
|
||||
await register_memory_bank(banks_impl)
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
query1 = "programming language"
|
||||
|
|
|
|||
|
|
@ -18,9 +18,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
|
|||
|
||||
async def resolve_impls_for_test(
|
||||
api: Api,
|
||||
models: List[ModelDef] = None,
|
||||
memory_banks: List[MemoryBankDef] = None,
|
||||
shields: List[ShieldDef] = None,
|
||||
):
|
||||
if "PROVIDER_CONFIG" not in os.environ:
|
||||
raise ValueError(
|
||||
|
|
@ -47,45 +44,11 @@ async def resolve_impls_for_test(
|
|||
provider_id = provider["provider_id"]
|
||||
print(f"No provider ID specified, picking first `{provider_id}`")
|
||||
|
||||
models = models or []
|
||||
shields = shields or []
|
||||
memory_banks = memory_banks or []
|
||||
|
||||
models = [
|
||||
ModelDef(
|
||||
**{
|
||||
**m.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for m in models
|
||||
]
|
||||
shields = [
|
||||
ShieldDef(
|
||||
**{
|
||||
**s.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for s in shields
|
||||
]
|
||||
memory_banks = [
|
||||
MemoryBankDef(
|
||||
**{
|
||||
**m.dict(),
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
for m in memory_banks
|
||||
]
|
||||
run_config = dict(
|
||||
built_at=datetime.now(),
|
||||
image_name="test-fixture",
|
||||
apis=[api],
|
||||
providers={api.value: [Provider(**provider)]},
|
||||
models=models,
|
||||
memory_banks=memory_banks,
|
||||
shields=shields,
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
impls = await resolve_impls_with_routing(run_config)
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
|
||||
|
||||
class ModelRegistryHelper:
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
|
||||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||
|
|
@ -33,3 +33,15 @@ class ModelRegistryHelper:
|
|||
raise ValueError(
|
||||
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
|
||||
)
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
models = []
|
||||
for llama_model, provider_model in self.stack_to_provider_models_map.items():
|
||||
models.append(ModelDef(identifier=llama_model, llama_model=llama_model))
|
||||
return models
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
if identifier not in self.stack_to_provider_models_map:
|
||||
return None
|
||||
|
||||
return ModelDef(identifier=identifier, llama_model=identifier)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue