mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 14:57:20 +00:00
Merge branch 'main' into eval_task_register
This commit is contained in:
commit
1b7e19d5d0
201 changed files with 1635 additions and 807 deletions
|
@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel):
|
|||
|
||||
|
||||
class ShieldStore(Protocol):
|
||||
def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||
async def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
@ -48,5 +48,5 @@ class Safety(Protocol):
|
|||
|
||||
@webmethod(route="/safety/run_shield")
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse: ...
|
||||
|
|
|
@ -46,7 +46,7 @@ class Shields(Protocol):
|
|||
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/shields/get", method="GET")
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/shields/register", method="POST")
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
|
||||
|
|
|
@ -12,6 +12,10 @@ import os
|
|||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
|
||||
TEMPLATES_PATH = Path(os.path.relpath(__file__)).parent.parent.parent / "templates"
|
||||
|
||||
|
||||
|
@ -176,6 +180,66 @@ class StackBuild(Subcommand):
|
|||
return
|
||||
self._run_stack_build_command_from_build_config(build_config)
|
||||
|
||||
def _generate_run_config(self, build_config: BuildConfig, build_dir: Path) -> None:
|
||||
"""
|
||||
Generate a run.yaml template file for user to edit from a build.yaml file
|
||||
"""
|
||||
import json
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import ImageType
|
||||
|
||||
apis = list(build_config.distribution_spec.providers.keys())
|
||||
run_config = StackRunConfig(
|
||||
built_at=datetime.now(),
|
||||
docker_image=(
|
||||
build_config.name
|
||||
if build_config.image_type == ImageType.docker.value
|
||||
else None
|
||||
),
|
||||
image_name=build_config.name,
|
||||
conda_env=(
|
||||
build_config.name
|
||||
if build_config.image_type == ImageType.conda.value
|
||||
else None
|
||||
),
|
||||
apis=apis,
|
||||
providers={},
|
||||
)
|
||||
# build providers dict
|
||||
provider_registry = get_provider_registry()
|
||||
for api in apis:
|
||||
run_config.providers[api] = []
|
||||
provider_types = build_config.distribution_spec.providers[api]
|
||||
if isinstance(provider_types, str):
|
||||
provider_types = [provider_types]
|
||||
|
||||
for i, provider_type in enumerate(provider_types):
|
||||
p_spec = Provider(
|
||||
provider_id=f"{provider_type}-{i}",
|
||||
provider_type=provider_type,
|
||||
config={},
|
||||
)
|
||||
config_type = instantiate_class_type(
|
||||
provider_registry[Api(api)][provider_type].config_class
|
||||
)
|
||||
p_spec.config = config_type()
|
||||
run_config.providers[api].append(p_spec)
|
||||
|
||||
os.makedirs(build_dir, exist_ok=True)
|
||||
run_config_file = build_dir / f"{build_config.name}-run.yaml"
|
||||
|
||||
with open(run_config_file, "w") as f:
|
||||
to_write = json.loads(run_config.model_dump_json())
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
cprint(
|
||||
f"You can now edit {run_config_file} and run `llama stack run {run_config_file}`",
|
||||
color="green",
|
||||
)
|
||||
|
||||
def _run_stack_build_command_from_build_config(
|
||||
self, build_config: BuildConfig
|
||||
) -> None:
|
||||
|
@ -183,48 +247,24 @@ class StackBuild(Subcommand):
|
|||
import os
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import build_image, ImageType
|
||||
from llama_stack.distribution.build import build_image
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.serialize import EnumEncoder
|
||||
|
||||
# save build.yaml spec for building same distribution again
|
||||
if build_config.image_type == ImageType.docker.value:
|
||||
# docker needs build file to be in the llama-stack repo dir to be able to copy over to the image
|
||||
llama_stack_path = Path(
|
||||
os.path.abspath(__file__)
|
||||
).parent.parent.parent.parent
|
||||
build_dir = llama_stack_path / "tmp/configs/"
|
||||
else:
|
||||
build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"
|
||||
|
||||
build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"
|
||||
os.makedirs(build_dir, exist_ok=True)
|
||||
build_file_path = build_dir / f"{build_config.name}-build.yaml"
|
||||
|
||||
with open(build_file_path, "w") as f:
|
||||
to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
|
||||
to_write = json.loads(build_config.model_dump_json())
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
return_code = build_image(build_config, build_file_path)
|
||||
if return_code != 0:
|
||||
return
|
||||
|
||||
configure_name = (
|
||||
build_config.name
|
||||
if build_config.image_type == "conda"
|
||||
else (f"llamastack-{build_config.name}")
|
||||
)
|
||||
if build_config.image_type == "conda":
|
||||
cprint(
|
||||
f"You can now run `llama stack configure {configure_name}`",
|
||||
color="green",
|
||||
)
|
||||
else:
|
||||
cprint(
|
||||
f"You can now edit your run.yaml file and run `docker run -it -p 5000:5000 {build_config.name}`. See full command in llama-stack/distributions/",
|
||||
color="green",
|
||||
)
|
||||
self._generate_run_config(build_config, build_dir)
|
||||
|
||||
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
import json
|
||||
|
|
|
@ -7,8 +7,6 @@
|
|||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
class StackConfigure(Subcommand):
|
||||
|
@ -39,123 +37,10 @@ class StackConfigure(Subcommand):
|
|||
)
|
||||
|
||||
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import ImageType
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
docker_image = None
|
||||
|
||||
build_config_file = Path(args.config)
|
||||
if build_config_file.exists():
|
||||
with open(build_config_file, "r") as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
self._configure_llama_distribution(build_config, args.output_dir)
|
||||
return
|
||||
|
||||
conda_dir = (
|
||||
Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
|
||||
)
|
||||
output = subprocess.check_output(["bash", "-c", "conda info --json"])
|
||||
conda_envs = json.loads(output.decode("utf-8"))["envs"]
|
||||
|
||||
for x in conda_envs:
|
||||
if x.endswith(f"/llamastack-{args.config}"):
|
||||
conda_dir = Path(x)
|
||||
break
|
||||
|
||||
build_config_file = Path(conda_dir) / f"{args.config}-build.yaml"
|
||||
if build_config_file.exists():
|
||||
with open(build_config_file, "r") as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
|
||||
cprint(f"Using {build_config_file}...", "green")
|
||||
self._configure_llama_distribution(build_config, args.output_dir)
|
||||
return
|
||||
|
||||
docker_image = args.config
|
||||
builds_dir = BUILDS_BASE_DIR / ImageType.docker.value
|
||||
if args.output_dir:
|
||||
builds_dir = Path(output_dir)
|
||||
os.makedirs(builds_dir, exist_ok=True)
|
||||
|
||||
script = pkg_resources.resource_filename(
|
||||
"llama_stack", "distribution/configure_container.sh"
|
||||
)
|
||||
script_args = [script, docker_image, str(builds_dir)]
|
||||
|
||||
return_code = run_with_pty(script_args)
|
||||
if return_code != 0:
|
||||
self.parser.error(
|
||||
f"Failed to configure container {docker_image} with return code {return_code}. Please run `llama stack build` first. "
|
||||
)
|
||||
|
||||
def _configure_llama_distribution(
|
||||
self,
|
||||
build_config: BuildConfig,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.configure import (
|
||||
configure_api_providers,
|
||||
parse_and_maybe_upgrade_config,
|
||||
)
|
||||
from llama_stack.distribution.utils.serialize import EnumEncoder
|
||||
|
||||
builds_dir = BUILDS_BASE_DIR / build_config.image_type
|
||||
if output_dir:
|
||||
builds_dir = Path(output_dir)
|
||||
os.makedirs(builds_dir, exist_ok=True)
|
||||
image_name = build_config.name.replace("::", "-")
|
||||
run_config_file = builds_dir / f"{image_name}-run.yaml"
|
||||
|
||||
if run_config_file.exists():
|
||||
cprint(
|
||||
f"Configuration already exists at `{str(run_config_file)}`. Will overwrite...",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
config_dict = yaml.safe_load(run_config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
else:
|
||||
config = StackRunConfig(
|
||||
built_at=datetime.now(),
|
||||
image_name=image_name,
|
||||
apis=list(build_config.distribution_spec.providers.keys()),
|
||||
providers={},
|
||||
)
|
||||
|
||||
config = configure_api_providers(config, build_config.distribution_spec)
|
||||
|
||||
config.docker_image = (
|
||||
image_name if build_config.image_type == "docker" else None
|
||||
)
|
||||
config.conda_env = image_name if build_config.image_type == "conda" else None
|
||||
|
||||
with open(run_config_file, "w") as f:
|
||||
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
cprint(
|
||||
f"> YAML configuration has been written to `{run_config_file}`.",
|
||||
color="blue",
|
||||
)
|
||||
|
||||
cprint(
|
||||
f"You can now run `llama stack run {image_name} --port PORT`",
|
||||
color="green",
|
||||
self.parser.error(
|
||||
"""
|
||||
DEPRECATED! llama stack configure has been deprecated.
|
||||
Please use llama stack run --config <path/to/run.yaml> instead.
|
||||
Please see example run.yaml in /distributions folder.
|
||||
"""
|
||||
)
|
||||
|
|
|
@ -45,7 +45,6 @@ class StackRun(Subcommand):
|
|||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import ImageType
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
|
@ -71,14 +70,12 @@ class StackRun(Subcommand):
|
|||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist. Please run `llama stack build` and `llama stack configure <name>` to generate a run.yaml file"
|
||||
f"File {str(config_file)} does not exist. Please run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
return
|
||||
|
||||
cprint(f"Using config `{config_file}`", "green")
|
||||
with open(config_file, "r") as f:
|
||||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
|
||||
if config.docker_image:
|
||||
script = pkg_resources.resource_filename(
|
||||
|
|
|
@ -36,7 +36,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
|||
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
|
||||
DOCKER_BINARY=${DOCKER_BINARY:-docker}
|
||||
DOCKER_OPTS=${DOCKER_OPTS:-}
|
||||
REPO_CONFIGS_DIR="$REPO_DIR/tmp/configs"
|
||||
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
|
||||
|
@ -115,8 +114,6 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
|
|||
|
||||
EOF
|
||||
|
||||
add_to_docker "ADD tmp/configs/$(basename "$build_file_path") ./llamastack-build.yaml"
|
||||
|
||||
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
|
||||
cat $TEMP_DIR/Dockerfile
|
||||
printf "\n"
|
||||
|
@ -138,7 +135,6 @@ set -x
|
|||
$DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
|
||||
|
||||
# clean up tmp/configs
|
||||
rm -rf $REPO_CONFIGS_DIR
|
||||
set +x
|
||||
|
||||
echo "Success!"
|
||||
|
|
|
@ -154,12 +154,12 @@ class SafetyRouter(Safety):
|
|||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
identifier: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
||||
shield_type=shield_type,
|
||||
return await self.routing_table.get_provider_impl(identifier).run_shield(
|
||||
identifier=identifier,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
|
|
@ -182,6 +182,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
|
||||
async def get_all_with_types(
|
||||
self, types: List[str]
|
||||
) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type in types]
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[ModelDefWithProvider]:
|
||||
|
@ -198,8 +204,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return await self.get_all_with_type("shield")
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
||||
return await self.get_object_by_identifier(shield_type)
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]:
|
||||
return await self.get_object_by_identifier(identifier)
|
||||
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
||||
await self.register_object(shield)
|
||||
|
@ -207,7 +213,14 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||
return await self.get_all_with_type("memory_bank")
|
||||
return await self.get_all_with_types(
|
||||
[
|
||||
MemoryBankType.vector.value,
|
||||
MemoryBankType.keyvalue.value,
|
||||
MemoryBankType.keyword.value,
|
||||
MemoryBankType.graph.value,
|
||||
]
|
||||
)
|
||||
|
||||
async def get_memory_bank(
|
||||
self, identifier: str
|
||||
|
|
|
@ -209,7 +209,8 @@ async def maybe_await(value):
|
|||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in await event_gen:
|
||||
event_gen = await event_gen
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
|
@ -229,7 +230,6 @@ async def sse_generator(event_gen):
|
|||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BedrockSafetyConfig(BaseModel):
|
||||
"""Configuration information for a guardrail that you want to use in the request."""
|
||||
|
||||
aws_profile: str = Field(
|
||||
default="default",
|
||||
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
||||
)
|
|
@ -145,11 +145,12 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int
|
||||
port: int = 0
|
||||
protocol: str = "http"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"http://{self.host}:{self.port}"
|
||||
return f"{self.protocol}://{self.host}:{self.port}"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.datasets import * # noqa: F403
|
|||
from autoevals.llm import Factuality
|
||||
from autoevals.ragas import AnswerCorrectness
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_average,
|
||||
)
|
||||
|
|
@ -4,10 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||
persistence_store: KVStoreConfig
|
||||
persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())
|
|
@ -32,18 +32,18 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(
|
||||
self, messages: List[Message], shield_types: List[str]
|
||||
self, messages: List[Message], identifiers: List[str]
|
||||
) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
shield_type=shield_type,
|
||||
identifier=identifier,
|
||||
messages=messages,
|
||||
)
|
||||
for shield_type in shield_types
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
for shield_type, response in zip(shield_types, responses):
|
||||
for identifier, response in zip(identifiers, responses):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
||||
|
@ -52,6 +52,6 @@ class ShieldRunnerMixin:
|
|||
raise SafetyException(violation)
|
||||
elif violation.violation_level == ViolationLevel.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield_type} raised a warning",
|
||||
f"[Warn]{identifier} raised a warning",
|
||||
color="red",
|
||||
)
|
|
@ -9,7 +9,7 @@ from typing import List
|
|||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
|
||||
from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin
|
||||
|
||||
from .builtin import BaseTool
|
||||
|
|
@ -14,6 +14,11 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_media_to_url,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
@ -87,6 +92,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await request_with_localized_media(request)
|
||||
|
||||
if request.stream:
|
||||
return self._stream_completion(request)
|
||||
|
@ -211,6 +217,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await request_with_localized_media(request)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
|
@ -388,3 +395,31 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def request_with_localized_media(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
) -> Union[ChatCompletionRequest, CompletionRequest]:
|
||||
if not request_has_media(request):
|
||||
return request
|
||||
|
||||
async def _convert_single_content(content):
|
||||
if isinstance(content, ImageMedia):
|
||||
url = await convert_image_media_to_url(content, download=True)
|
||||
return ImageMedia(image=URL(uri=url))
|
||||
else:
|
||||
return content
|
||||
|
||||
async def _convert_content(content):
|
||||
if isinstance(content, list):
|
||||
return [await _convert_single_content(c) for c in content]
|
||||
else:
|
||||
return await _convert_single_content(content)
|
||||
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
for m in request.messages:
|
||||
m.content = await _convert_content(m.content)
|
||||
else:
|
||||
request.content = await _convert_content(request.content)
|
||||
|
||||
return request
|
|
@ -27,7 +27,7 @@ from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
|||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.inference.config import (
|
||||
from llama_stack.providers.inline.meta_reference.inference.config import (
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
)
|
||||
|
21
llama_stack/providers/inline/meta_reference/memory/config.py
Normal file
21
llama_stack/providers/inline/meta_reference/memory/config.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FaissImplConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "faiss_store.db").as_posix()
|
||||
) # Uses SQLite config specific to FAISS storage
|
|
@ -16,6 +16,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
|
@ -28,6 +29,8 @@ from .config import FaissImplConfig
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_BANKS_PREFIX = "memory_banks:"
|
||||
|
||||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
id_by_index: Dict[int, str]
|
||||
|
@ -69,10 +72,25 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
self.kvstore = None
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
# Load existing banks from kvstore
|
||||
start_key = MEMORY_BANKS_PREFIX
|
||||
end_key = f"{MEMORY_BANKS_PREFIX}\xff"
|
||||
stored_banks = await self.kvstore.range(start_key, end_key)
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
for bank_data in stored_banks:
|
||||
bank = VectorMemoryBankDef.model_validate_json(bank_data)
|
||||
index = BankWithIndex(
|
||||
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# Cleanup if needed
|
||||
pass
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
|
@ -82,6 +100,14 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
# Store in kvstore
|
||||
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=memory_bank.json(),
|
||||
)
|
||||
|
||||
# Store in cache
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
|
||||
)
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
|
||||
from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class TestFaissMemoryImpl:
|
||||
@pytest.fixture
|
||||
def faiss_impl(self):
|
||||
# Create a temporary SQLite database file
|
||||
temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name))
|
||||
return FaissMemoryImpl(config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self, faiss_impl):
|
||||
# Test empty initialization
|
||||
await faiss_impl.initialize()
|
||||
assert len(faiss_impl.cache) == 0
|
||||
|
||||
# Test initialization with existing banks
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
type=MemoryBankType.vector.value,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
# Register a bank and reinitialize to test loading
|
||||
await faiss_impl.register_memory_bank(bank)
|
||||
|
||||
# Create new instance to test initialization with existing data
|
||||
new_impl = FaissMemoryImpl(faiss_impl.config)
|
||||
await new_impl.initialize()
|
||||
|
||||
assert len(new_impl.cache) == 1
|
||||
assert "test_bank" in new_impl.cache
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_memory_bank(self, faiss_impl):
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
type=MemoryBankType.vector.value,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
await faiss_impl.initialize()
|
||||
await faiss_impl.register_memory_bank(bank)
|
||||
|
||||
assert "test_bank" in faiss_impl.cache
|
||||
assert faiss_impl.cache["test_bank"].bank == bank
|
||||
|
||||
# Verify persistence
|
||||
new_impl = FaissMemoryImpl(faiss_impl.config)
|
||||
await new_impl.initialize()
|
||||
assert "test_bank" in new_impl.cache
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
|
@ -13,15 +13,15 @@ from llama_stack.apis.datasetio import * # noqa: F403
|
|||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
|
||||
EqualityScoringFn,
|
||||
)
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
|
||||
LlmAsJudgeScoringFn,
|
||||
)
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
|
||||
SubsetOfScoringFn,
|
||||
)
|
||||
|
|
@ -4,18 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
BaseScoringFn,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_accuracy,
|
||||
)
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.equality import (
|
||||
equality,
|
||||
)
|
||||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
BaseScoringFn,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
|
@ -12,10 +12,10 @@ from llama_stack.apis.scoring import * # noqa: F401, F403
|
|||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
import re
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_average,
|
||||
)
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
|
||||
llm_as_judge_8b_correctness,
|
||||
)
|
||||
|
|
@ -4,17 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
BaseScoringFn,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_accuracy,
|
||||
)
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
|
||||
subset_of,
|
||||
)
|
||||
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue