Mirror of https://github.com/meta-llama/llama-stack.git
This PR adds the ability to generate embeddings in all supported inference providers.

```
pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py -k "bedrock" --inference-model="amazon.titan-embed-text-v2:0" --env EMBEDDING_DIMENSION=1024

pytest -v -s -k "vllm" --inference-model="intfloat/e5-mistral-7b-instruct" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=4096 --env VLLM_URL="http://localhost:9798/v1"

pytest -v -s --inference-model="nomic-ai/nomic-embed-text-v1.5" llama_stack/providers/tests/inference/test_embeddings.py -k "fireworks" --env FIREWORKS_API_KEY=<API_KEY> --env EMBEDDING_DIMENSION=128

pytest -v -s --inference-model="togethercomputer/m2-bert-80M-2k-retrieval" llama_stack/providers/tests/inference/test_embeddings.py -k "together" --env TOGETHER_API_KEY=<API_KEY> --env EMBEDDING_DIMENSION=768

pytest -v -s -k "ollama" --inference-model="all-minilm:v8" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384

torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384
```
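For orientation, the commands above drive the embeddings endpoint of each provider. The following is a minimal sketch (not the actual test code) of the call they exercise; the `embeddings()` method name, its `model_id`/`contents` parameters, and the `embeddings` field on the response are assumptions about the provider interface and may differ from the real signatures.

```python
# Hypothetical sketch of the embeddings call exercised by the tests above.
# Assumes an initialized inference provider exposing an async `embeddings()`
# method that takes a model id plus a list of text contents and returns an
# object with an `embeddings` field (a list of float vectors).
from typing import List


async def check_embeddings(inference, model_id: str, expected_dim: int) -> List[List[float]]:
    response = await inference.embeddings(
        model_id=model_id,
        contents=["Hello, world!", "This is a second document."],
    )
    # Each returned vector should match the dimension passed via EMBEDDING_DIMENSION.
    assert all(len(vec) == expected_dim for vec in response.embeddings)
    return response.embeddings
```

With the ollama configuration above, for instance, one would expect `expected_dim=384` for `all-minilm:v8`.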
110 lines · 4.3 KiB · Python
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.distribution.datatypes import *  # noqa: F403


EMBEDDING_DEPS = [
    "blobfile",
    "chardet",
    "pypdf",
    "tqdm",
    "numpy",
    "scikit-learn",
    "scipy",
    "nltk",
    "sentencepiece",
    "transformers",
    # this happens to work because special dependencies are always installed last
    # so if there was a regular torch installed first, this would be ignored
    # we need a better way to do this to identify potential conflicts, etc.
    # for now, this lets us significantly reduce the size of the container which
    # does not have any "local" inference code (and hence does not need GPU-enabled torch)
    "torch --index-url https://download.pytorch.org/whl/cpu",
    "sentence-transformers --no-deps",
]


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.memory,
            provider_type="inline::meta-reference",
            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
            module="llama_stack.providers.inline.memory.faiss",
            config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
            deprecation_warning="Please use the `inline::faiss` provider instead.",
            api_dependencies=[Api.inference],
        ),
        InlineProviderSpec(
            api=Api.memory,
            provider_type="inline::faiss",
            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
            module="llama_stack.providers.inline.memory.faiss",
            config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
            api_dependencies=[Api.inference],
        ),
        remote_provider_spec(
            Api.memory,
            AdapterSpec(
                adapter_type="chromadb",
                pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
                module="llama_stack.providers.remote.memory.chroma",
                config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
            ),
            api_dependencies=[Api.inference],
        ),
        InlineProviderSpec(
            api=Api.memory,
            provider_type="inline::chromadb",
            pip_packages=EMBEDDING_DEPS + ["chromadb"],
            module="llama_stack.providers.inline.memory.chroma",
            config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
        ),
        remote_provider_spec(
            Api.memory,
            AdapterSpec(
                adapter_type="pgvector",
                pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
                module="llama_stack.providers.remote.memory.pgvector",
                config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
            ),
            api_dependencies=[Api.inference],
        ),
        remote_provider_spec(
            Api.memory,
            AdapterSpec(
                adapter_type="weaviate",
                pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
                module="llama_stack.providers.remote.memory.weaviate",
                config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
                provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
            ),
            api_dependencies=[Api.inference],
        ),
        remote_provider_spec(
            api=Api.memory,
            adapter=AdapterSpec(
                adapter_type="sample",
                pip_packages=[],
                module="llama_stack.providers.remote.memory.sample",
                config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
            ),
            api_dependencies=[],
        ),
        remote_provider_spec(
            Api.memory,
            AdapterSpec(
                adapter_type="qdrant",
                pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
                module="llama_stack.providers.remote.memory.qdrant",
                config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
            ),
            api_dependencies=[Api.inference],
        ),
    ]
```
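As a quick sanity check on the registry above, the provider specs it returns can be inspected directly. The snippet below is a hypothetical sketch, not part of the file: it assumes it runs in the same module as `available_providers()` and that every returned spec exposes `provider_type` and `pip_packages` attributes corresponding to the fields constructed above (for the remote specs, derived from their `AdapterSpec`).

```python
# Hypothetical sketch: list the memory providers registered above along with
# their pip dependencies. Assumes `provider_type` and `pip_packages` are
# accessible on every spec returned by available_providers().
if __name__ == "__main__":
    for spec in available_providers():
        packages = ", ".join(spec.pip_packages) or "<none>"
        print(f"{spec.provider_type}: {packages}")
```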