mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 06:53:47 +00:00
Merge branch 'main' into eval_task_register
This commit is contained in:
commit
1b7e19d5d0
201 changed files with 1635 additions and 807 deletions
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BedrockSafetyConfig(BaseModel):
|
||||
"""Configuration information for a guardrail that you want to use in the request."""
|
||||
|
||||
aws_profile: str = Field(
|
||||
default="default",
|
||||
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
||||
)
|
|
@ -145,11 +145,12 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int
|
||||
port: int = 0
|
||||
protocol: str = "http"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"http://{self.host}:{self.port}"
|
||||
return f"{self.protocol}://{self.host}:{self.port}"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.datasets import * # noqa: F403
|
|||
from autoevals.llm import Factuality
|
||||
from autoevals.ragas import AnswerCorrectness
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_average,
|
||||
)
|
||||
|
|
@ -4,10 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||
persistence_store: KVStoreConfig
|
||||
persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())
|
|
@ -32,18 +32,18 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(
|
||||
self, messages: List[Message], shield_types: List[str]
|
||||
self, messages: List[Message], identifiers: List[str]
|
||||
) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
shield_type=shield_type,
|
||||
identifier=identifier,
|
||||
messages=messages,
|
||||
)
|
||||
for shield_type in shield_types
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
for shield_type, response in zip(shield_types, responses):
|
||||
for identifier, response in zip(identifiers, responses):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
||||
|
@ -52,6 +52,6 @@ class ShieldRunnerMixin:
|
|||
raise SafetyException(violation)
|
||||
elif violation.violation_level == ViolationLevel.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield_type} raised a warning",
|
||||
f"[Warn]{identifier} raised a warning",
|
||||
color="red",
|
||||
)
|
|
@ -9,7 +9,7 @@ from typing import List
|
|||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
|
||||
from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin
|
||||
|
||||
from .builtin import BaseTool
|
||||
|
|
@ -14,6 +14,11 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_media_to_url,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
@ -87,6 +92,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await request_with_localized_media(request)
|
||||
|
||||
if request.stream:
|
||||
return self._stream_completion(request)
|
||||
|
@ -211,6 +217,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await request_with_localized_media(request)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
|
@ -388,3 +395,31 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def request_with_localized_media(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
) -> Union[ChatCompletionRequest, CompletionRequest]:
|
||||
if not request_has_media(request):
|
||||
return request
|
||||
|
||||
async def _convert_single_content(content):
|
||||
if isinstance(content, ImageMedia):
|
||||
url = await convert_image_media_to_url(content, download=True)
|
||||
return ImageMedia(image=URL(uri=url))
|
||||
else:
|
||||
return content
|
||||
|
||||
async def _convert_content(content):
|
||||
if isinstance(content, list):
|
||||
return [await _convert_single_content(c) for c in content]
|
||||
else:
|
||||
return await _convert_single_content(content)
|
||||
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
for m in request.messages:
|
||||
m.content = await _convert_content(m.content)
|
||||
else:
|
||||
request.content = await _convert_content(request.content)
|
||||
|
||||
return request
|
|
@ -27,7 +27,7 @@ from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
|||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.inference.config import (
|
||||
from llama_stack.providers.inline.meta_reference.inference.config import (
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
)
|
||||
|
21
llama_stack/providers/inline/meta_reference/memory/config.py
Normal file
21
llama_stack/providers/inline/meta_reference/memory/config.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FaissImplConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "faiss_store.db").as_posix()
|
||||
) # Uses SQLite config specific to FAISS storage
|
|
@ -16,6 +16,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
|
@ -28,6 +29,8 @@ from .config import FaissImplConfig
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_BANKS_PREFIX = "memory_banks:"
|
||||
|
||||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
id_by_index: Dict[int, str]
|
||||
|
@ -69,10 +72,25 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
self.kvstore = None
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
# Load existing banks from kvstore
|
||||
start_key = MEMORY_BANKS_PREFIX
|
||||
end_key = f"{MEMORY_BANKS_PREFIX}\xff"
|
||||
stored_banks = await self.kvstore.range(start_key, end_key)
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
for bank_data in stored_banks:
|
||||
bank = VectorMemoryBankDef.model_validate_json(bank_data)
|
||||
index = BankWithIndex(
|
||||
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# Cleanup if needed
|
||||
pass
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
|
@ -82,6 +100,14 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
# Store in kvstore
|
||||
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=memory_bank.json(),
|
||||
)
|
||||
|
||||
# Store in cache
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
|
||||
)
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
|
||||
from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class TestFaissMemoryImpl:
|
||||
@pytest.fixture
|
||||
def faiss_impl(self):
|
||||
# Create a temporary SQLite database file
|
||||
temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name))
|
||||
return FaissMemoryImpl(config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self, faiss_impl):
|
||||
# Test empty initialization
|
||||
await faiss_impl.initialize()
|
||||
assert len(faiss_impl.cache) == 0
|
||||
|
||||
# Test initialization with existing banks
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
type=MemoryBankType.vector.value,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
# Register a bank and reinitialize to test loading
|
||||
await faiss_impl.register_memory_bank(bank)
|
||||
|
||||
# Create new instance to test initialization with existing data
|
||||
new_impl = FaissMemoryImpl(faiss_impl.config)
|
||||
await new_impl.initialize()
|
||||
|
||||
assert len(new_impl.cache) == 1
|
||||
assert "test_bank" in new_impl.cache
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_memory_bank(self, faiss_impl):
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
type=MemoryBankType.vector.value,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
await faiss_impl.initialize()
|
||||
await faiss_impl.register_memory_bank(bank)
|
||||
|
||||
assert "test_bank" in faiss_impl.cache
|
||||
assert faiss_impl.cache["test_bank"].bank == bank
|
||||
|
||||
# Verify persistence
|
||||
new_impl = FaissMemoryImpl(faiss_impl.config)
|
||||
await new_impl.initialize()
|
||||
assert "test_bank" in new_impl.cache
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
|
@ -13,15 +13,15 @@ from llama_stack.apis.datasetio import * # noqa: F403
|
|||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
|
||||
EqualityScoringFn,
|
||||
)
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
|
||||
LlmAsJudgeScoringFn,
|
||||
)
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
|
||||
SubsetOfScoringFn,
|
||||
)
|
||||
|
|
@ -4,18 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
BaseScoringFn,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_accuracy,
|
||||
)
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.equality import (
|
||||
equality,
|
||||
)
|
||||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
BaseScoringFn,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
|
@ -12,10 +12,10 @@ from llama_stack.apis.scoring import * # noqa: F401, F403
|
|||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
import re
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_average,
|
||||
)
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
|
||||
llm_as_judge_8b_correctness,
|
||||
)
|
||||
|
|
@ -4,17 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
BaseScoringFn,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_accuracy,
|
||||
)
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
|
||||
subset_of,
|
||||
)
|
||||
|
|
@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
"scikit-learn",
|
||||
]
|
||||
+ kvstore_dependencies(),
|
||||
module="llama_stack.providers.impls.meta_reference.agents",
|
||||
config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceAgentsImplConfig",
|
||||
module="llama_stack.providers.inline.meta_reference.agents",
|
||||
config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.safety,
|
||||
|
@ -36,8 +36,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.adapters.agents.sample",
|
||||
config_class="llama_stack.providers.adapters.agents.sample.SampleConfig",
|
||||
module="llama_stack.providers.remote.agents.sample",
|
||||
config_class="llama_stack.providers.remote.agents.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -15,8 +15,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.datasetio,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=["pandas"],
|
||||
module="llama_stack.providers.impls.meta_reference.datasetio",
|
||||
config_class="llama_stack.providers.impls.meta_reference.datasetio.MetaReferenceDatasetIOConfig",
|
||||
module="llama_stack.providers.inline.meta_reference.datasetio",
|
||||
config_class="llama_stack.providers.inline.meta_reference.datasetio.MetaReferenceDatasetIOConfig",
|
||||
api_dependencies=[],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -15,8 +15,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.eval,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.impls.meta_reference.eval",
|
||||
config_class="llama_stack.providers.impls.meta_reference.eval.MetaReferenceEvalConfig",
|
||||
module="llama_stack.providers.inline.meta_reference.eval",
|
||||
config_class="llama_stack.providers.inline.meta_reference.eval.MetaReferenceEvalConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
|
|
|
@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=META_REFERENCE_DEPS,
|
||||
module="llama_stack.providers.impls.meta_reference.inference",
|
||||
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
|
||||
module="llama_stack.providers.inline.meta_reference.inference",
|
||||
config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
|
@ -40,16 +40,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
"torchao==0.5.0",
|
||||
]
|
||||
),
|
||||
module="llama_stack.providers.impls.meta_reference.inference",
|
||||
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
|
||||
module="llama_stack.providers.inline.meta_reference.inference",
|
||||
config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.adapters.inference.sample",
|
||||
config_class="llama_stack.providers.adapters.inference.sample.SampleConfig",
|
||||
module="llama_stack.providers.remote.inference.sample",
|
||||
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -57,26 +57,26 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="ollama",
|
||||
pip_packages=["ollama", "aiohttp"],
|
||||
config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.adapters.inference.ollama",
|
||||
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.remote.inference.ollama",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vllm",
|
||||
pip_packages=["openai"],
|
||||
module="llama_stack.providers.remote.inference.vllm",
|
||||
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
|
||||
),
|
||||
),
|
||||
# remote_provider_spec(
|
||||
# api=Api.inference,
|
||||
# adapter=AdapterSpec(
|
||||
# adapter_type="vllm",
|
||||
# pip_packages=["openai"],
|
||||
# module="llama_stack.providers.adapters.inference.vllm",
|
||||
# config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig",
|
||||
# ),
|
||||
# ),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="tgi",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.adapters.inference.tgi",
|
||||
config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -84,8 +84,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="hf::serverless",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.adapters.inference.tgi",
|
||||
config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig",
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -93,8 +93,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="hf::endpoint",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.adapters.inference.tgi",
|
||||
config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -104,8 +104,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[
|
||||
"fireworks-ai",
|
||||
],
|
||||
module="llama_stack.providers.adapters.inference.fireworks",
|
||||
config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
|
||||
module="llama_stack.providers.remote.inference.fireworks",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -115,9 +115,9 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[
|
||||
"together",
|
||||
],
|
||||
module="llama_stack.providers.adapters.inference.together",
|
||||
config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
|
||||
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
|
||||
module="llama_stack.providers.remote.inference.together",
|
||||
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -125,8 +125,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.adapters.inference.bedrock",
|
||||
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
|
||||
module="llama_stack.providers.remote.inference.bedrock",
|
||||
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -136,8 +136,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.adapters.inference.databricks",
|
||||
config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig",
|
||||
module="llama_stack.providers.remote.inference.databricks",
|
||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
),
|
||||
),
|
||||
InlineProviderSpec(
|
||||
|
@ -146,7 +146,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[
|
||||
"vllm",
|
||||
],
|
||||
module="llama_stack.providers.impls.vllm",
|
||||
config_class="llama_stack.providers.impls.vllm.VLLMConfig",
|
||||
module="llama_stack.providers.inline.vllm",
|
||||
config_class="llama_stack.providers.inline.vllm.VLLMConfig",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -36,15 +36,15 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.memory,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.impls.meta_reference.memory",
|
||||
config_class="llama_stack.providers.impls.meta_reference.memory.FaissImplConfig",
|
||||
module="llama_stack.providers.inline.meta_reference.memory",
|
||||
config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
AdapterSpec(
|
||||
adapter_type="chromadb",
|
||||
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
||||
module="llama_stack.providers.adapters.memory.chroma",
|
||||
module="llama_stack.providers.remote.memory.chroma",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -52,8 +52,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
AdapterSpec(
|
||||
adapter_type="pgvector",
|
||||
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
|
||||
module="llama_stack.providers.adapters.memory.pgvector",
|
||||
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
|
||||
module="llama_stack.providers.remote.memory.pgvector",
|
||||
config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -61,9 +61,9 @@ def available_providers() -> List[ProviderSpec]:
|
|||
AdapterSpec(
|
||||
adapter_type="weaviate",
|
||||
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
|
||||
module="llama_stack.providers.adapters.memory.weaviate",
|
||||
config_class="llama_stack.providers.adapters.memory.weaviate.WeaviateConfig",
|
||||
provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
|
||||
module="llama_stack.providers.remote.memory.weaviate",
|
||||
config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -71,8 +71,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.adapters.memory.sample",
|
||||
config_class="llama_stack.providers.adapters.memory.sample.SampleConfig",
|
||||
module="llama_stack.providers.remote.memory.sample",
|
||||
config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -80,8 +80,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
AdapterSpec(
|
||||
adapter_type="qdrant",
|
||||
pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
|
||||
module="llama_stack.providers.adapters.memory.qdrant",
|
||||
config_class="llama_stack.providers.adapters.memory.qdrant.QdrantConfig",
|
||||
module="llama_stack.providers.remote.memory.qdrant",
|
||||
config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
module="llama_stack.providers.impls.meta_reference.safety",
|
||||
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
|
||||
module="llama_stack.providers.inline.meta_reference.safety",
|
||||
config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
],
|
||||
|
@ -35,8 +35,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.adapters.safety.sample",
|
||||
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
|
||||
module="llama_stack.providers.remote.safety.sample",
|
||||
config_class="llama_stack.providers.remote.safety.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -44,8 +44,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.adapters.safety.bedrock",
|
||||
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
|
||||
module="llama_stack.providers.remote.safety.bedrock",
|
||||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -55,9 +55,9 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[
|
||||
"together",
|
||||
],
|
||||
module="llama_stack.providers.adapters.safety.together",
|
||||
config_class="llama_stack.providers.adapters.safety.together.TogetherSafetyConfig",
|
||||
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
|
||||
module="llama_stack.providers.remote.safety.together",
|
||||
config_class="llama_stack.providers.remote.safety.together.TogetherSafetyConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator",
|
||||
),
|
||||
),
|
||||
InlineProviderSpec(
|
||||
|
@ -66,8 +66,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[
|
||||
"codeshield",
|
||||
],
|
||||
module="llama_stack.providers.impls.meta_reference.codeshield",
|
||||
config_class="llama_stack.providers.impls.meta_reference.codeshield.CodeShieldConfig",
|
||||
module="llama_stack.providers.inline.meta_reference.codeshield",
|
||||
config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig",
|
||||
api_dependencies=[],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -15,8 +15,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.scoring,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.impls.meta_reference.scoring",
|
||||
config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig",
|
||||
module="llama_stack.providers.inline.meta_reference.scoring",
|
||||
config_class="llama_stack.providers.inline.meta_reference.scoring.MetaReferenceScoringConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
|
@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.scoring,
|
||||
provider_type="braintrust",
|
||||
pip_packages=["autoevals", "openai"],
|
||||
module="llama_stack.providers.impls.braintrust.scoring",
|
||||
config_class="llama_stack.providers.impls.braintrust.scoring.BraintrustScoringConfig",
|
||||
module="llama_stack.providers.inline.braintrust.scoring",
|
||||
config_class="llama_stack.providers.inline.braintrust.scoring.BraintrustScoringConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue