diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index deee16ae8..eb9aaa540 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Tuple
 
 from llama_stack.distribution.datatypes import Api
 
@@ -24,13 +24,22 @@ class MemoryRouter(Memory):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        self.api = Api.memory.value
         self.routing_table = routing_table
+        self.bank_id_to_type = {}
 
     async def initialize(self) -> None:
-        pass
+        await self.routing_table.initialize(self.api)
 
     async def shutdown(self) -> None:
-        pass
+        await self.routing_table.shutdown(self.api)
+
+    def get_provider_from_bank_id(self, bank_id: str) -> Any:
+        bank_type = self.bank_id_to_type.get(bank_id)
+        if not bank_type:
+            raise ValueError(f"Could not find bank type for {bank_id}")
+
+        return self.routing_table.get_provider_impl(self.api, bank_type)
 
     async def create_memory_bank(
         self,
@@ -39,9 +48,16 @@ class MemoryRouter(Memory):
         url: Optional[URL] = None,
     ) -> MemoryBank:
         print("MemoryRouter: create_memory_bank")
+        bank_type = config.type
+        bank = await self.routing_table.get_provider_impl(
+            self.api, bank_type
+        ).create_memory_bank(name, config, url)
+        self.bank_id_to_type[bank.bank_id] = bank_type
+        return bank
 
     async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
         print("MemoryRouter: get_memory_bank")
+        return await self.get_provider_from_bank_id(bank_id).get_memory_bank(bank_id)
 
     async def insert_documents(
         self,
@@ -50,6 +66,9 @@ class MemoryRouter(Memory):
         ttl_seconds: Optional[int] = None,
     ) -> None:
         print("MemoryRouter: insert_documents")
+        return await self.get_provider_from_bank_id(bank_id).insert_documents(
+            bank_id, documents, ttl_seconds
+        )
 
     async def query_documents(
         self,
@@ -57,7 +76,9 @@ class MemoryRouter(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        print("query_documents")
+        return await self.get_provider_from_bank_id(bank_id).query_documents(
+            bank_id, query, params
+        )
 
 
 class InferenceRouter(Inference):
@@ -81,14 +102,13 @@ class InferenceRouter(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        # zero-shot tool definitions as input to the model
-        tools: Optional[List[ToolDefinition]] = list,
+        tools: Optional[List[ToolDefinition]] = [],
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        # TODO: we need to fix streaming response to align provider implementations with Protocol
+        # TODO: we need to fix streaming response to align provider implementations with Protocol.
         async for chunk in self.routing_table.get_provider_impl(
             self.api, model
         ).chat_completion(
diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml
index a400011d3..df540674b 100644
--- a/llama_stack/examples/router-table-run.yaml
+++ b/llama_stack/examples/router-table-run.yaml
@@ -3,40 +3,40 @@ image_name: local
 docker_image: null
 conda_env: local
 apis_to_serve:
-- inference
-# - memory
+# - inference
+- memory
 - telemetry
 provider_map:
   telemetry:
     provider_id: meta-reference
     config: {}
 provider_routing_table:
-  inference:
-  - routing_key: Meta-Llama3.1-8B-Instruct
-    provider_id: meta-reference
-    config:
-      model: Meta-Llama3.1-8B-Instruct
-      quantization: null
-      torch_seed: null
-      max_seq_len: 4096
-      max_batch_size: 1
-    # - routing_key: Meta-Llama3.1-8B
-    #   provider_id: meta-reference
-    #   config:
-    #     model: Meta-Llama3.1-8B
-    #     quantization: null
-    #     torch_seed: null
-    #     max_seq_len: 4096
-    #     max_batch_size: 1
-  # memory:
-  # - routing_key: keyvalue
-  #   provider_id: remote::pgvector
-  #   config:
-  #     host: localhost
-  #     port: 5432
-  #     db: vectordb
-  #     user: vectoruser
-  #     password: xxxx
-  # - routing_key: vector
+  # inference:
+  # - routing_key: Meta-Llama3.1-8B-Instruct
   #   provider_id: meta-reference
-  #   config: {}
+  #   config:
+  #     model: Meta-Llama3.1-8B-Instruct
+  #     quantization: null
+  #     torch_seed: null
+  #     max_seq_len: 4096
+  #     max_batch_size: 1
+  # - routing_key: Meta-Llama3.1-8B
+  #   provider_id: meta-reference
+  #   config:
+  #     model: Meta-Llama3.1-8B
+  #     quantization: null
+  #     torch_seed: null
+  #     max_seq_len: 4096
+  #     max_batch_size: 1
+  memory:
+  - routing_key: keyvalue
+    provider_id: remote::pgvector
+    config:
+      host: localhost
+      port: 5432
+      db: vectordb
+      user: vectoruser
+      password: xxxx
+  - routing_key: vector
+    provider_id: meta-reference
+    config: {}
diff --git a/llama_stack/examples/simple-local-run.yaml b/llama_stack/examples/simple-local-run.yaml
new file mode 100644
index 000000000..d4e3d202e
--- /dev/null
+++ b/llama_stack/examples/simple-local-run.yaml
@@ -0,0 +1,38 @@
+built_at: '2024-09-19T22:50:36.239761'
+image_name: simple-local
+docker_image: null
+conda_env: simple-local
+apis_to_serve:
+- inference
+- safety
+- agents
+- memory
+provider_map:
+  inference:
+    provider_id: meta-reference
+    config:
+      model: Meta-Llama3.1-8B-Instruct
+      quantization: null
+      torch_seed: null
+      max_seq_len: 4096
+      max_batch_size: 1
+  safety:
+    provider_id: meta-reference
+    config:
+      llama_guard_shield:
+        model: Llama-Guard-3-8B
+        excluded_categories: []
+        disable_input_check: false
+        disable_output_check: false
+      prompt_guard_shield:
+        model: Prompt-Guard-86M
+  agents:
+    provider_id: meta-reference
+    config: {}
+  memory:
+    provider_id: meta-reference
+    config: {}
+  telemetry:
+    provider_id: meta-reference
+    config: {}
+provider_routing_table: {}
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 597a4cb55..8b4d34106 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -57,7 +57,7 @@ class MetaReferenceInferenceImpl(Inference):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = None,
+        tools: Optional[List[ToolDefinition]] = [],
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
@@ -70,7 +70,7 @@ class MetaReferenceInferenceImpl(Inference):
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools or [],
+            tools=tools,
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
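
A minimal, self-contained sketch of the bank-type dispatch pattern that the MemoryRouter changes above introduce: the router remembers which bank type each bank_id was created with, then forwards later calls for that bank to the provider registered under that routing key. FakeProvider, FakeRoutingTable, and MemoryRouterSketch are hypothetical stand-ins invented for illustration only; they are not llama_stack classes, and the real router resolves providers through the provider_routing_table shown in the YAML configs.

# Illustrative sketch only; names below are hypothetical, not llama_stack APIs.
import asyncio
from typing import Dict


class FakeProvider:
    def __init__(self, name: str) -> None:
        self.name = name

    async def query_documents(self, bank_id: str, query: str) -> str:
        return f"{self.name} answered {query!r} for bank {bank_id}"


class FakeRoutingTable:
    def __init__(self, providers: Dict[str, FakeProvider]) -> None:
        # routing_key (bank type) -> provider implementation
        self._providers = providers

    def get_provider_impl(self, api: str, routing_key: str) -> FakeProvider:
        return self._providers[routing_key]


class MemoryRouterSketch:
    """Remembers the bank type of each bank_id and dispatches calls
    for that bank to the matching provider."""

    def __init__(self, routing_table: FakeRoutingTable) -> None:
        self.api = "memory"
        self.routing_table = routing_table
        self.bank_id_to_type: Dict[str, str] = {}

    def register_bank(self, bank_id: str, bank_type: str) -> None:
        # In the real router this mapping is recorded inside create_memory_bank().
        self.bank_id_to_type[bank_id] = bank_type

    async def query_documents(self, bank_id: str, query: str) -> str:
        bank_type = self.bank_id_to_type.get(bank_id)
        if not bank_type:
            raise ValueError(f"Could not find bank type for {bank_id}")
        provider = self.routing_table.get_provider_impl(self.api, bank_type)
        return await provider.query_documents(bank_id, query)


async def main() -> None:
    table = FakeRoutingTable(
        {
            "vector": FakeProvider("meta-reference"),
            "keyvalue": FakeProvider("remote::pgvector"),
        }
    )
    router = MemoryRouterSketch(table)
    router.register_bank("bank-123", "vector")
    print(await router.query_documents("bank-123", "what is a router?"))


if __name__ == "__main__":
    asyncio.run(main())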