mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Merge branch 'meta-llama:main' into main
This commit is contained in:
commit
3ee415dc35
16 changed files with 140 additions and 116 deletions
|
@ -190,7 +190,7 @@ class Inference(Protocol):
|
|||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = list,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
|
|
|
@ -13,9 +13,9 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
import fire
|
||||
import httpx
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.utils.memory.file_utils import data_url_from_file
|
||||
|
@ -120,7 +120,7 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
name="test_bank",
|
||||
config=VectorMemoryBankConfig(
|
||||
bank_id="test_bank",
|
||||
embedding_model="dragon-roberta-query-2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
|
@ -129,7 +129,7 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
|
||||
retrieved_bank = await client.get_memory_bank(bank.bank_id)
|
||||
assert retrieved_bank is not None
|
||||
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
|
||||
assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
|
|
|
@ -9,12 +9,12 @@ import json
|
|||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.distribution.utils.serialize import EnumEncoder
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
class ModelDescribe(Subcommand):
|
||||
"""Show details about a model"""
|
||||
|
@ -52,7 +52,7 @@ class ModelDescribe(Subcommand):
|
|||
colored(model.descriptor(), "white", attrs=["bold"]),
|
||||
),
|
||||
("HuggingFace ID", model.huggingface_repo or "<Not Available>"),
|
||||
("Description", model.description_markdown),
|
||||
("Description", model.description),
|
||||
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
|
||||
("Weights format", model.quantization_format.value),
|
||||
("Model params.json", json.dumps(model.arch_args, indent=4)),
|
||||
|
|
|
@ -66,6 +66,14 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
if provider_spec.docker_image:
|
||||
raise ValueError("A stack's dependencies cannot have a docker image")
|
||||
|
||||
special_deps = []
|
||||
deps = []
|
||||
for package in package_deps.pip_packages:
|
||||
if "--no-deps" in package or "--index-url" in package:
|
||||
special_deps.append(package)
|
||||
else:
|
||||
deps.append(package)
|
||||
|
||||
if build_config.image_type == ImageType.docker.value:
|
||||
script = pkg_resources.resource_filename(
|
||||
"llama_stack", "distribution/build_container.sh"
|
||||
|
@ -75,7 +83,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
build_config.name,
|
||||
package_deps.docker_image,
|
||||
str(build_file_path),
|
||||
" ".join(package_deps.pip_packages),
|
||||
" ".join(deps),
|
||||
]
|
||||
else:
|
||||
script = pkg_resources.resource_filename(
|
||||
|
@ -84,9 +92,12 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
args = [
|
||||
script,
|
||||
build_config.name,
|
||||
" ".join(package_deps.pip_packages),
|
||||
" ".join(deps),
|
||||
]
|
||||
|
||||
if special_deps:
|
||||
args.append("#".join(special_deps))
|
||||
|
||||
return_code = run_with_pty(args)
|
||||
if return_code != 0:
|
||||
cprint(
|
||||
|
|
|
@ -17,14 +17,16 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
|
|||
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
|
||||
fi
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [ "$#" -ne 2 ]; then
|
||||
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies>" >&2
|
||||
if [ "$#" -lt 2 ]; then
|
||||
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 <distribution_type> mybuild 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$3"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
build_name="$1"
|
||||
env_name="llamastack-$build_name"
|
||||
pip_dependencies="$2"
|
||||
|
@ -43,6 +45,7 @@ source "$SCRIPT_DIR/common.sh"
|
|||
ensure_conda_env_python310() {
|
||||
local env_name="$1"
|
||||
local pip_dependencies="$2"
|
||||
local special_pip_deps="$3"
|
||||
local python_version="3.10"
|
||||
|
||||
# Check if conda command is available
|
||||
|
@ -78,7 +81,12 @@ ensure_conda_env_python310() {
|
|||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
$CONDA_PREFIX/bin/pip install fastapi libcst
|
||||
$CONDA_PREFIX/bin/pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION $pip_dependencies
|
||||
$CONDA_PREFIX/bin/pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION \
|
||||
$pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
$CONDA_PREFIX/bin/pip install --no-deps "$special_pip_deps"
|
||||
fi
|
||||
else
|
||||
# Re-installing llama-stack in the new conda environment
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
|
@ -105,11 +113,16 @@ ensure_conda_env_python310() {
|
|||
fi
|
||||
|
||||
# Install pip dependencies
|
||||
if [ -n "$pip_dependencies" ]; then
|
||||
printf "Installing pip dependencies: $pip_dependencies\n"
|
||||
$CONDA_PREFIX/bin/pip install $pip_dependencies
|
||||
printf "Installing pip dependencies\n"
|
||||
$CONDA_PREFIX/bin/pip install $pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<< "$special_pip_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
$CONDA_PREFIX/bin/pip install $part
|
||||
done
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_conda_env_python310 "$env_name" "$pip_dependencies"
|
||||
ensure_conda_env_python310 "$env_name" "$pip_dependencies" "$special_pip_deps"
|
||||
|
|
|
@ -4,12 +4,16 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
|||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
|
||||
if [ "$#" -ne 4 ]; then
|
||||
echo "Usage: $0 <build_name> <docker_base> <pip_dependencies>
|
||||
echo "Example: $0 my-fastapi-app python:3.9-slim 'fastapi uvicorn'
|
||||
if [ "$#" -lt 4 ]; then
|
||||
echo "Usage: $0 <build_name> <docker_base> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 my-fastapi-app python:3.9-slim 'fastapi uvicorn' " >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$5"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
build_name="$1"
|
||||
image_name="llamastack-$build_name"
|
||||
docker_base=$2
|
||||
|
@ -21,8 +25,6 @@ RED='\033[0;31m'
|
|||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
|
||||
DOCKER_BINARY=${DOCKER_BINARY:-docker}
|
||||
|
@ -85,6 +87,13 @@ if [ -n "$pip_dependencies" ]; then
|
|||
add_to_docker "RUN pip install $pip_dependencies"
|
||||
fi
|
||||
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<< "$special_pip_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
add_to_docker "RUN pip install $part"
|
||||
done
|
||||
fi
|
||||
|
||||
add_to_docker <<EOF
|
||||
|
||||
# This would be good in production but for debugging flexibility lets not add it right now
|
||||
|
|
|
@ -103,8 +103,7 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
# TODO: we need to fix streaming response to align provider implementations with Protocol.
|
||||
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
|
||||
params = dict(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
|
@ -113,6 +112,10 @@ class InferenceRouter(Inference):
|
|||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
# TODO: we need to fix streaming response to align provider implementations with Protocol.
|
||||
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
|
||||
**params
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
|
|
@ -33,8 +33,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.providers.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Optional[Any]:
|
||||
return self.providers.get(routing_key)
|
||||
def get_provider_impl(self, routing_key: str) -> Any:
|
||||
if routing_key not in self.providers:
|
||||
raise ValueError(f"Could not find provider for {routing_key}")
|
||||
return self.providers[routing_key]
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
|
|
|
@ -368,17 +368,19 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
providers = all_providers[info.router_api]
|
||||
|
||||
inner_specs = []
|
||||
inner_deps = []
|
||||
for rt_entry in routing_table:
|
||||
if rt_entry.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_id])
|
||||
inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
|
||||
|
||||
specs[source_api] = RoutingTableProviderSpec(
|
||||
api=source_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
api_dependencies=inner_deps,
|
||||
inner_specs=inner_specs,
|
||||
)
|
||||
configs[source_api] = routing_table
|
||||
|
|
|
@ -119,7 +119,7 @@ class TGIAdapter(Inference):
|
|||
)
|
||||
stop_reason = None
|
||||
if response.details.finish_reason:
|
||||
if response.details.finish_reason == "stop":
|
||||
if response.details.finish_reason in ["stop", "eos_token"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif response.details.finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
|
|
@ -7,11 +7,11 @@
|
|||
from .config import SafetyConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SafetyConfig, _deps):
|
||||
async def get_provider_impl(config: SafetyConfig, deps):
|
||||
from .safety import MetaReferenceSafetyImpl
|
||||
|
||||
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceSafetyImpl(config)
|
||||
impl = MetaReferenceSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -7,8 +7,10 @@
|
|||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
|
@ -34,20 +36,11 @@ def resolve_and_get_path(model_name: str) -> str:
|
|||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety):
|
||||
def __init__(self, config: SafetyConfig) -> None:
|
||||
def __init__(self, config: SafetyConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
async def initialize(self) -> None:
|
||||
shield_cfg = self.config.llama_guard_shield
|
||||
if shield_cfg is not None:
|
||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
_ = LlamaGuardShield.instance(
|
||||
model_dir=model_dir,
|
||||
excluded_categories=shield_cfg.excluded_categories,
|
||||
disable_input_check=shield_cfg.disable_input_check,
|
||||
disable_output_check=shield_cfg.disable_output_check,
|
||||
)
|
||||
|
||||
shield_cfg = self.config.prompt_guard_shield
|
||||
if shield_cfg is not None:
|
||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
|
@ -91,11 +84,18 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||
cfg = self.config
|
||||
if typ == MetaReferenceShieldType.llama_guard:
|
||||
cfg = cfg.llama_guard_shield
|
||||
assert (
|
||||
cfg.llama_guard_shield is not None
|
||||
cfg is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
|
||||
return LlamaGuardShield.instance(model_dir=model_dir)
|
||||
|
||||
return LlamaGuardShield(
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
disable_input_check=cfg.disable_input_check,
|
||||
disable_output_check=cfg.disable_output_check,
|
||||
)
|
||||
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
||||
assert (
|
||||
cfg.prompt_guard_shield is not None
|
||||
|
|
|
@ -9,9 +9,8 @@ import re
|
|||
from string import Template
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from llama_models.llama3.api.datatypes import Message, Role
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
|
||||
|
||||
|
@ -100,39 +99,17 @@ PROMPT_TEMPLATE = Template(
|
|||
|
||||
|
||||
class LlamaGuardShield(ShieldBase):
|
||||
@staticmethod
|
||||
def instance(
|
||||
on_violation_action=OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
) -> "LlamaGuardShield":
|
||||
global _INSTANCE
|
||||
if _INSTANCE is None:
|
||||
_INSTANCE = LlamaGuardShield(
|
||||
on_violation_action,
|
||||
model_dir,
|
||||
excluded_categories,
|
||||
disable_input_check,
|
||||
disable_output_check,
|
||||
)
|
||||
return _INSTANCE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
||||
dtype = torch.bfloat16
|
||||
|
||||
assert model_dir is not None, "Llama Guard model_dir is None"
|
||||
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
|
||||
|
@ -140,18 +117,12 @@ class LlamaGuardShield(ShieldBase):
|
|||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
||||
self.device = "cuda"
|
||||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.disable_input_check = disable_input_check
|
||||
self.disable_output_check = disable_output_check
|
||||
|
||||
# load model
|
||||
torch_dtype = torch.bfloat16
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir, torch_dtype=torch_dtype, device_map=self.device
|
||||
)
|
||||
|
||||
def check_unsafe_response(self, response: str) -> Optional[str]:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
if match:
|
||||
|
@ -212,26 +183,21 @@ class LlamaGuardShield(ShieldBase):
|
|||
)
|
||||
else:
|
||||
prompt = self.build_prompt(messages)
|
||||
llama_guard_input = {
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
[llama_guard_input], return_tensors="pt", tokenize=True
|
||||
).to(self.device)
|
||||
prompt_len = input_ids.shape[1]
|
||||
output = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=20,
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
pad_token_id=0,
|
||||
)
|
||||
generated_tokens = output.sequences[:, prompt_len:]
|
||||
|
||||
response = self.tokenizer.decode(
|
||||
generated_tokens[0], skip_special_tokens=True
|
||||
)
|
||||
response = response.strip()
|
||||
shield_response = self.get_shield_response(response)
|
||||
# TODO: llama-stack inference protocol has issues with non-streaming inference code
|
||||
content = ""
|
||||
async for chunk in self.inference_api.chat_completion(
|
||||
model=self.model,
|
||||
messages=[
|
||||
UserMessage(content=prompt),
|
||||
],
|
||||
stream=True,
|
||||
):
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.progress:
|
||||
assert isinstance(event.delta, str)
|
||||
content += event.delta
|
||||
|
||||
content = content.strip()
|
||||
shield_response = self.get_shield_response(content)
|
||||
return shield_response
|
||||
|
|
|
@ -8,11 +8,25 @@ from typing import List
|
|||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
EMBEDDING_DEPS = [
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"pypdf",
|
||||
"sentence-transformers",
|
||||
"tqdm",
|
||||
"numpy",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"nltk",
|
||||
"sentencepiece",
|
||||
"transformers",
|
||||
# this happens to work because special dependencies are always installed last
|
||||
# so if there was a regular torch installed first, this would be ignored
|
||||
# we need a better way to do this to identify potential conflicts, etc.
|
||||
# for now, this lets us significantly reduce the size of the container which
|
||||
# does not have any "local" inference code (and hence does not need GPU-enabled torch)
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
"sentence-transformers --no-deps",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -15,13 +15,15 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.safety,
|
||||
provider_id="meta-reference",
|
||||
pip_packages=[
|
||||
"accelerate",
|
||||
"codeshield",
|
||||
"torch",
|
||||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
module="llama_stack.providers.impls.meta_reference.safety",
|
||||
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
|
|
|
@ -25,20 +25,22 @@ from llama_stack.apis.memory import * # noqa: F403
|
|||
|
||||
ALL_MINILM_L6_V2_DIMENSION = 384
|
||||
|
||||
EMBEDDING_MODEL = None
|
||||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
||||
def get_embedding_model() -> "SentenceTransformer":
|
||||
global EMBEDDING_MODEL
|
||||
def get_embedding_model(model: str) -> "SentenceTransformer":
|
||||
global EMBEDDING_MODELS
|
||||
|
||||
if EMBEDDING_MODEL is None:
|
||||
print("Loading sentence transformer")
|
||||
loaded_model = EMBEDDING_MODELS.get(model)
|
||||
if loaded_model is not None:
|
||||
return loaded_model
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
print(f"Loading sentence transformer for {model}...")
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
|
||||
return EMBEDDING_MODEL
|
||||
loaded_model = SentenceTransformer(model)
|
||||
EMBEDDING_MODELS[model] = loaded_model
|
||||
return loaded_model
|
||||
|
||||
|
||||
def parse_data_url(data_url: str):
|
||||
|
@ -151,7 +153,7 @@ class BankWithIndex:
|
|||
self,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(self.bank.config.embedding_model)
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
chunks = make_overlapped_chunks(
|
||||
|
@ -187,6 +189,6 @@ class BankWithIndex:
|
|||
else:
|
||||
query_str = _process(query)
|
||||
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(self.bank.config.embedding_model)
|
||||
query_vector = model.encode([query_str])[0].astype(np.float32)
|
||||
return await self.index.query(query_vector, k)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue