mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
parent
665ab1f812
commit
39c27a3d8c
8 changed files with 45 additions and 213 deletions
|
@ -1,5 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
|
@ -1,17 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Any, List, Tuple
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api
|
|
||||||
|
|
||||||
|
|
||||||
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
|
|
||||||
from .memory import MemoryRouterImpl
|
|
||||||
|
|
||||||
impl = MemoryRouterImpl(inner_impls, deps)
|
|
||||||
await impl.initialize()
|
|
||||||
return impl
|
|
|
@ -1,91 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryRouterImpl(Memory):
|
|
||||||
"""Routes to an provider based on the memory bank type"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
inner_impls: List[Tuple[str, Any]],
|
|
||||||
deps: List[Api],
|
|
||||||
) -> None:
|
|
||||||
self.deps = deps
|
|
||||||
|
|
||||||
bank_types = [v.value for v in MemoryBankType]
|
|
||||||
|
|
||||||
self.providers = {}
|
|
||||||
for routing_key, provider_impl in inner_impls:
|
|
||||||
if routing_key not in bank_types:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown routing key `{routing_key}` for memory bank type"
|
|
||||||
)
|
|
||||||
self.providers[routing_key] = provider_impl
|
|
||||||
|
|
||||||
self.bank_id_to_type = {}
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
for p in self.providers.values():
|
|
||||||
await p.shutdown()
|
|
||||||
|
|
||||||
def get_provider(self, bank_type):
|
|
||||||
if bank_type not in self.providers:
|
|
||||||
raise ValueError(f"Memory bank type {bank_type} not supported")
|
|
||||||
|
|
||||||
return self.providers[bank_type]
|
|
||||||
|
|
||||||
async def create_memory_bank(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
config: MemoryBankConfig,
|
|
||||||
url: Optional[URL] = None,
|
|
||||||
) -> MemoryBank:
|
|
||||||
provider = self.get_provider(config.type)
|
|
||||||
bank = await provider.create_memory_bank(name, config, url)
|
|
||||||
self.bank_id_to_type[bank.bank_id] = config.type
|
|
||||||
return bank
|
|
||||||
|
|
||||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
|
||||||
bank_type = self.bank_id_to_type.get(bank_id)
|
|
||||||
if not bank_type:
|
|
||||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
|
||||||
|
|
||||||
provider = self.get_provider(bank_type)
|
|
||||||
return await provider.get_memory_bank(bank_id)
|
|
||||||
|
|
||||||
async def insert_documents(
|
|
||||||
self,
|
|
||||||
bank_id: str,
|
|
||||||
documents: List[MemoryBankDocument],
|
|
||||||
ttl_seconds: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
bank_type = self.bank_id_to_type.get(bank_id)
|
|
||||||
if not bank_type:
|
|
||||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
|
||||||
|
|
||||||
provider = self.get_provider(bank_type)
|
|
||||||
return await provider.insert_documents(bank_id, documents, ttl_seconds)
|
|
||||||
|
|
||||||
async def query_documents(
|
|
||||||
self,
|
|
||||||
bank_id: str,
|
|
||||||
query: InterleavedTextMedia,
|
|
||||||
params: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> QueryDocumentsResponse:
|
|
||||||
bank_type = self.bank_id_to_type.get(bank_id)
|
|
||||||
if not bank_type:
|
|
||||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
|
||||||
|
|
||||||
provider = self.get_provider(bank_type)
|
|
||||||
return await provider.query_documents(bank_id, query, params)
|
|
|
@ -291,56 +291,7 @@ async def resolve_impls_with_routing(
|
||||||
stack_run_config: StackRunConfig,
|
stack_run_config: StackRunConfig,
|
||||||
) -> Dict[Api, Any]:
|
) -> Dict[Api, Any]:
|
||||||
|
|
||||||
all_providers = api_providers()
|
raise NotImplementedError("This is not implemented yet")
|
||||||
specs = {}
|
|
||||||
|
|
||||||
for api_str in stack_run_config.apis_to_serve:
|
|
||||||
api = Api(api_str)
|
|
||||||
providers = all_providers[api]
|
|
||||||
|
|
||||||
# check for regular providers without routing
|
|
||||||
if api_str in stack_run_config.provider_map:
|
|
||||||
provider_map_entry = stack_run_config.provider_map[api_str]
|
|
||||||
if provider_map_entry.provider_id not in providers:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
|
||||||
)
|
|
||||||
specs[api] = providers[provider_map_entry.provider_id]
|
|
||||||
|
|
||||||
# check for routing table, we need to pass routing table to the router implementation
|
|
||||||
if api_str in stack_run_config.provider_routing_table:
|
|
||||||
router_entry = stack_run_config.provider_routing_table[api_str]
|
|
||||||
inner_specs = []
|
|
||||||
for rt_entry in router_entry:
|
|
||||||
if rt_entry.provider_id not in providers:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
|
||||||
)
|
|
||||||
inner_specs.append(providers[rt_entry.provider_id])
|
|
||||||
|
|
||||||
specs[api] = RouterProviderSpec(
|
|
||||||
api=api,
|
|
||||||
module=f"llama_stack.distribution.routers.{api.value.lower()}",
|
|
||||||
api_dependencies=[],
|
|
||||||
inner_specs=inner_specs,
|
|
||||||
)
|
|
||||||
|
|
||||||
sorted_specs = topological_sort(specs.values())
|
|
||||||
|
|
||||||
impls = {}
|
|
||||||
for spec in sorted_specs:
|
|
||||||
api = spec.api
|
|
||||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
|
||||||
if api.value in stack_run_config.provider_map:
|
|
||||||
provider_config = stack_run_config.provider_map[api.value]
|
|
||||||
elif api.value in stack_run_config.provider_routing_table:
|
|
||||||
provider_config = stack_run_config.provider_routing_table[api.value]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Cannot find provider_config for Api {api.value}")
|
|
||||||
impl = await instantiate_provider(spec, deps, provider_config)
|
|
||||||
impls[api] = impl
|
|
||||||
|
|
||||||
return impls, specs
|
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls(
|
async def resolve_impls(
|
||||||
|
@ -405,7 +356,6 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
if config.provider_routing_table is not None:
|
if config.provider_routing_table is not None:
|
||||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||||
else:
|
else:
|
||||||
# keeping this for backwards compatibility,could
|
|
||||||
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||||
|
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
|
|
|
@ -33,7 +33,6 @@ async def instantiate_provider(
|
||||||
|
|
||||||
assert isinstance(provider_config, GenericProviderConfig)
|
assert isinstance(provider_config, GenericProviderConfig)
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
print("!!!", provider_config)
|
|
||||||
config = config_type(**provider_config.config)
|
config = config_type(**provider_config.config)
|
||||||
args = [config, deps]
|
args = [config, deps]
|
||||||
elif isinstance(provider_spec, RouterProviderSpec):
|
elif isinstance(provider_spec, RouterProviderSpec):
|
||||||
|
|
|
@ -3,27 +3,26 @@ image_name: local
|
||||||
docker_image: null
|
docker_image: null
|
||||||
conda_env: local
|
conda_env: local
|
||||||
apis_to_serve:
|
apis_to_serve:
|
||||||
# - inference
|
- inference
|
||||||
- memory
|
- memory
|
||||||
- telemetry
|
|
||||||
provider_map:
|
provider_map:
|
||||||
telemetry:
|
# use builtin-router as dummy field
|
||||||
provider_id: meta-reference
|
memory: builtin-router
|
||||||
config: {}
|
inference: builtin-router
|
||||||
provider_routing_table:
|
provider_routing_table:
|
||||||
# inference:
|
inference:
|
||||||
# - routing_key: Meta-Llama3.1-8B-Instruct
|
- routing_key: Meta-Llama3.1-8B-Instruct
|
||||||
# provider_id: meta-reference
|
provider_id: meta-reference
|
||||||
# config:
|
config:
|
||||||
# model: Meta-Llama3.1-8B-Instruct
|
model: Meta-Llama3.1-8B-Instruct
|
||||||
# quantization: null
|
quantization: null
|
||||||
# torch_seed: null
|
torch_seed: null
|
||||||
# max_seq_len: 4096
|
max_seq_len: 4096
|
||||||
# max_batch_size: 1
|
max_batch_size: 1
|
||||||
# - routing_key: Meta-Llama3.1-8B
|
- routing_key: Meta-Llama3.1-8B
|
||||||
# provider_id: remote::ollama
|
provider_id: remote::ollama
|
||||||
# config:
|
config:
|
||||||
# url: http:ollama-url-1.com
|
url: http:ollama-url-1.com
|
||||||
memory:
|
memory:
|
||||||
- routing_key: keyvalue
|
- routing_key: keyvalue
|
||||||
provider_id: remote::pgvector
|
provider_id: remote::pgvector
|
||||||
|
@ -32,7 +31,6 @@ provider_routing_table:
|
||||||
port: 5432
|
port: 5432
|
||||||
db: vectordb
|
db: vectordb
|
||||||
user: vectoruser
|
user: vectoruser
|
||||||
password: xxxx
|
|
||||||
- routing_key: vector
|
- routing_key: vector
|
||||||
provider_id: meta-reference
|
provider_id: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
|
|
|
@ -128,36 +128,35 @@ class PGVectorMemoryAdapter(Memory):
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
print("Init PGVector!")
|
try:
|
||||||
# try:
|
self.conn = psycopg2.connect(
|
||||||
# self.conn = psycopg2.connect(
|
host=self.config.host,
|
||||||
# host=self.config.host,
|
port=self.config.port,
|
||||||
# port=self.config.port,
|
database=self.config.db,
|
||||||
# database=self.config.db,
|
user=self.config.user,
|
||||||
# user=self.config.user,
|
password=self.config.password,
|
||||||
# password=self.config.password,
|
)
|
||||||
# )
|
self.cursor = self.conn.cursor()
|
||||||
# self.cursor = self.conn.cursor()
|
|
||||||
|
|
||||||
# version = check_extension_version(self.cursor)
|
version = check_extension_version(self.cursor)
|
||||||
# if version:
|
if version:
|
||||||
# print(f"Vector extension version: {version}")
|
print(f"Vector extension version: {version}")
|
||||||
# else:
|
else:
|
||||||
# raise RuntimeError("Vector extension is not installed.")
|
raise RuntimeError("Vector extension is not installed.")
|
||||||
|
|
||||||
# self.cursor.execute(
|
self.cursor.execute(
|
||||||
# """
|
"""
|
||||||
# CREATE TABLE IF NOT EXISTS metadata_store (
|
CREATE TABLE IF NOT EXISTS metadata_store (
|
||||||
# key TEXT PRIMARY KEY,
|
key TEXT PRIMARY KEY,
|
||||||
# data JSONB
|
data JSONB
|
||||||
# )
|
)
|
||||||
# """
|
"""
|
||||||
# )
|
)
|
||||||
# except Exception as e:
|
except Exception as e:
|
||||||
# import traceback
|
import traceback
|
||||||
|
|
||||||
# traceback.print_exc()
|
traceback.print_exc()
|
||||||
# raise RuntimeError("Could not connect to PGVector database server") from e
|
raise RuntimeError("Could not connect to PGVector database server") from e
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -68,8 +68,7 @@ class FaissMemoryImpl(Memory):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None: ...
|
||||||
print("INIT meta-reference")
|
|
||||||
|
|
||||||
async def shutdown(self) -> None: ...
|
async def shutdown(self) -> None: ...
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue