Add provider deprecation support; change directory structure (#397)

* Add provider deprecation support; change directory structure

* fix a couple dangling imports

* move the meta_reference safety dir also
Ashwin Bharambe, 2024-11-07 13:04:53 -08:00 (committed by GitHub)
parent 36e2538eb0
commit 694c142b89
58 changed files with 61 additions and 120 deletions

@@ -8,6 +8,8 @@ import inspect
 from typing import Any, Dict, List, Set

+from termcolor import cprint
+
 from llama_stack.providers.datatypes import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403
@@ -97,6 +99,12 @@ async def resolve_impls(
            )
            p = provider_registry[api][provider.provider_type]
+           if p.deprecation_warning:
+               cprint(
+                   f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
+                   "red",
+                   attrs=["bold"],
+               )
            p.deps__ = [a.value for a in p.api_dependencies]
            spec = ProviderWithSpec(
                spec=p,
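
Together with the new ProviderSpec field below, this makes deprecation purely advisory: a deprecated provider still resolves and loads, but the operator gets a bold red notice at startup. A minimal standalone sketch of the same check, assuming only termcolor (the helper name is hypothetical, not part of this diff):

from typing import Optional

from termcolor import cprint


def warn_if_deprecated(provider_type: str, api: str, warning: Optional[str]) -> None:
    # Mirror the inline check in resolve_impls: warn loudly, never block resolution.
    if warning:
        cprint(
            f"Provider `{provider_type}` for API `{api}` is deprecated and will be "
            f"removed in a future release: {warning}",
            "red",
            attrs=["bold"],
        )


warn_if_deprecated("meta-reference", "memory", "Please use the `faiss` provider instead.")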

@@ -82,6 +82,10 @@ class ProviderSpec(BaseModel):
        default_factory=list,
        description="Higher-level API surfaces may depend on other providers to provide their functionality",
    )
+   deprecation_warning: Optional[str] = Field(
+       default=None,
+       description="If this provider is deprecated, specify the warning message here",
+   )

    # used internally by the resolver; this is a hack for now
    deps__: List[str] = Field(default_factory=list)
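
A provider opts into deprecation simply by setting the new field on its spec; nothing else changes. A sketch using the same message the memory registry ships further down (the import path and surrounding values are assumptions for illustration):

from llama_stack.providers.datatypes import Api, InlineProviderSpec

deprecated_spec = InlineProviderSpec(
    api=Api.memory,
    provider_type="meta-reference",
    pip_packages=[],  # illustrative; the real entry uses EMBEDDING_DEPS + ["faiss-cpu"]
    module="llama_stack.providers.inline.memory.faiss",
    config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
    deprecation_warning="Please use the `faiss` provider instead.",
)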

@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from pydantic import BaseModel, Field
+
 from llama_stack.providers.utils.kvstore import KVStoreConfig
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-from pydantic import BaseModel, Field

 class MetaReferenceAgentsImplConfig(BaseModel):

@@ -11,9 +11,8 @@ from datetime import datetime
 from typing import List, Optional

 from llama_stack.apis.agents import *  # noqa: F403
+from pydantic import BaseModel

 from llama_stack.providers.utils.kvstore import KVStore
-from pydantic import BaseModel

 class AgentSessionInfo(BaseModel):

@@ -10,14 +10,13 @@ from jinja2 import Template
 from llama_models.llama3.api import *  # noqa: F403
+from termcolor import cprint  # noqa: F401

 from llama_stack.apis.agents import (
     DefaultMemoryQueryGeneratorConfig,
     LLMMemoryQueryGeneratorConfig,
     MemoryQueryGenerator,
     MemoryQueryGeneratorConfig,
 )
-from termcolor import cprint  # noqa: F401
 from llama_stack.apis.inference import *  # noqa: F403

@@ -9,8 +9,7 @@ from typing import List
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import *  # noqa: F403

-from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin
+from ..safety import ShieldRunnerMixin
 from .builtin import BaseTool

@@ -10,9 +10,8 @@ from llama_models.datatypes import *  # noqa: F403
 from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import *  # noqa: F401, F403
+from pydantic import BaseModel, Field, field_validator

 from llama_stack.providers.utils.inference import supported_inference_models
-from pydantic import BaseModel, Field, field_validator

 class MetaReferenceInferenceConfig(BaseModel):

@@ -35,13 +35,12 @@ from termcolor import cprint
 from llama_stack.apis.inference import *  # noqa: F403
+from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     augment_content_with_response_format_prompt,
     chat_completion_request_to_messages,
 )
-from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

 from .config import (
     Fp8QuantizationConfig,

@@ -28,13 +28,13 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )

+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
 from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 from typing_extensions import Annotated

-from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
 from .generation import TokenResult

@@ -20,16 +20,15 @@ from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from llama_models.sku_list import resolve_model
+from llama_stack.apis.inference import QuantizationType

 from termcolor import cprint
 from torch import nn, Tensor
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

-from llama_stack.apis.inference import QuantizationType
-from llama_stack.providers.inline.meta_reference.inference.config import (
-    MetaReferenceQuantizedInferenceConfig,
-)
+from ..config import MetaReferenceQuantizedInferenceConfig

 def swiglu_wrapper(

@@ -5,9 +5,9 @@
 # the root directory of this source tree.

 from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field, field_validator

 from llama_stack.providers.utils.inference import supported_inference_models
-from pydantic import BaseModel, Field, field_validator

 @json_schema_type

@@ -5,13 +5,13 @@
 # the root directory of this source tree.

 from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel

 from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
 from llama_stack.providers.utils.kvstore.config import (
     KVStoreConfig,
     SqliteKVStoreConfig,
 )
-from pydantic import BaseModel

 @json_schema_type

@@ -8,10 +8,11 @@ import logging
 from typing import Any, Dict, List, Optional

+import faiss
 import numpy as np
 from numpy.typing import NDArray

-import faiss
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.memory import *  # noqa: F403

@ -1,73 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
import pytest
from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig
from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestFaissMemoryImpl:
@pytest.fixture
def faiss_impl(self):
# Create a temporary SQLite database file
temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name))
return FaissMemoryImpl(config)
@pytest.mark.asyncio
async def test_initialize(self, faiss_impl):
# Test empty initialization
await faiss_impl.initialize()
assert len(faiss_impl.cache) == 0
# Test initialization with existing banks
bank = VectorMemoryBankDef(
identifier="test_bank",
type=MemoryBankType.vector.value,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
# Register a bank and reinitialize to test loading
await faiss_impl.register_memory_bank(bank)
# Create new instance to test initialization with existing data
new_impl = FaissMemoryImpl(faiss_impl.config)
await new_impl.initialize()
assert len(new_impl.cache) == 1
assert "test_bank" in new_impl.cache
@pytest.mark.asyncio
async def test_register_memory_bank(self, faiss_impl):
bank = VectorMemoryBankDef(
identifier="test_bank",
type=MemoryBankType.vector.value,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
await faiss_impl.initialize()
await faiss_impl.register_memory_bank(bank)
assert "test_bank" in faiss_impl.cache
assert faiss_impl.cache["test_bank"].bank == bank
# Verify persistence
new_impl = FaissMemoryImpl(faiss_impl.config)
await new_impl.initialize()
assert "test_bank" in new_impl.cache
if __name__ == "__main__":
pytest.main([__file__])
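
The deleted test exercised the old llama_stack.providers.inline.meta_reference.memory package, which no longer exists after the move. An equivalent smoke test against the relocated module could look like this (a sketch; the .config and .faiss submodule paths under the new directory are assumptions, not shown in this diff):

import tempfile

import pytest

from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
from llama_stack.providers.inline.memory.faiss.faiss import FaissMemoryImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


@pytest.mark.asyncio
async def test_initialize_empty():
    # Same fixture shape as the deleted test, pointed at the new module path.
    temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
    config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name))
    impl = FaissMemoryImpl(config)
    await impl.initialize()
    assert len(impl.cache) == 0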

@@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]:
             "scikit-learn",
         ]
         + kvstore_dependencies(),
-        module="llama_stack.providers.inline.meta_reference.agents",
-        config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig",
+        module="llama_stack.providers.inline.agents.meta_reference",
+        config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
         api_dependencies=[
             Api.inference,
             Api.safety,

@@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]:
         api=Api.inference,
         provider_type="meta-reference",
         pip_packages=META_REFERENCE_DEPS,
-        module="llama_stack.providers.inline.meta_reference.inference",
-        config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig",
+        module="llama_stack.providers.inline.inference.meta_reference",
+        config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
     ),
     InlineProviderSpec(
         api=Api.inference,
@@ -40,8 +40,17 @@
                 "torchao==0.5.0",
             ]
         ),
-        module="llama_stack.providers.inline.meta_reference.inference",
-        config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
+        module="llama_stack.providers.inline.inference.meta_reference",
+        config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
+    ),
+    InlineProviderSpec(
+        api=Api.inference,
+        provider_type="vllm",
+        pip_packages=[
+            "vllm",
+        ],
+        module="llama_stack.providers.inline.inference.vllm",
+        config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
     ),
     remote_provider_spec(
         api=Api.inference,
@@ -140,13 +149,4 @@
             config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
         ),
     ),
-    InlineProviderSpec(
-        api=Api.inference,
-        provider_type="vllm",
-        pip_packages=[
-            "vllm",
-        ],
-        module="llama_stack.providers.inline.vllm",
-        config_class="llama_stack.providers.inline.vllm.VLLMConfig",
-    ),
 ]
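
Note the pattern in this and the following registry files: the restructure flips the package nesting from provider-first (inline.meta_reference.inference) to API-first (inline.inference.meta_reference), and the relocated vllm entry now sits under inference as well. Downstream imports need a matching one-line change, e.g. (paths taken from the hunks above):

# Before #397, provider-first nesting:
#   from llama_stack.providers.inline.meta_reference.inference import (
#       MetaReferenceInferenceConfig,
#   )
# After, API-first nesting:
from llama_stack.providers.inline.inference.meta_reference import (
    MetaReferenceInferenceConfig,
)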

@@ -36,8 +36,16 @@ def available_providers() -> List[ProviderSpec]:
         api=Api.memory,
         provider_type="meta-reference",
         pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
-        module="llama_stack.providers.inline.meta_reference.memory",
-        config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig",
+        module="llama_stack.providers.inline.memory.faiss",
+        config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+        deprecation_warning="Please use the `faiss` provider instead.",
+    ),
+    InlineProviderSpec(
+        api=Api.memory,
+        provider_type="faiss",
+        pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
+        module="llama_stack.providers.inline.memory.faiss",
+        config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
     ),
     remote_provider_spec(
         Api.memory,
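
Both memory entries point at the same faiss implementation, so the deprecation is effectively a rename of meta-reference to faiss; only the old name carries the warning. A quick sanity check one could run (assuming the registry module lives at llama_stack.providers.registry.memory):

from llama_stack.providers.registry.memory import available_providers

specs = {p.provider_type: p for p in available_providers()}

# Old alias and new name resolve to the same implementation...
assert specs["meta-reference"].module == specs["faiss"].module
# ...but only the old name triggers the resolver's startup warning.
assert specs["meta-reference"].deprecation_warning is not None
assert specs["faiss"].deprecation_warning is None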

@@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]:
             "transformers",
             "torch --index-url https://download.pytorch.org/whl/cpu",
         ],
-        module="llama_stack.providers.inline.meta_reference.safety",
-        config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig",
+        module="llama_stack.providers.inline.safety.meta_reference",
+        config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
         api_dependencies=[
             Api.inference,
         ],
@@ -54,8 +54,8 @@ def available_providers() -> List[ProviderSpec]:
         pip_packages=[
             "codeshield",
         ],
-        module="llama_stack.providers.inline.meta_reference.codeshield",
-        config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig",
+        module="llama_stack.providers.inline.safety.meta_reference",
+        config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig",
         api_dependencies=[],
     ),
 ]

@@ -11,7 +11,7 @@ import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.agents import (
+from llama_stack.providers.inline.agents.meta_reference import (
     MetaReferenceAgentsImplConfig,
 )

@@ -10,7 +10,7 @@ import pytest
 import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.inference import (
+from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )

@@ -11,7 +11,7 @@ import pytest
 import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.memory import FaissImplConfig
+from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig

@@ -8,7 +8,7 @@ import pytest
 import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.safety import (
+from llama_stack.providers.inline.safety.meta_reference import (
     LlamaGuardShieldConfig,
     SafetyConfig,
 )