From 20a43028770a4ecec6b54c8b5338718800fb54a7 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sat, 21 Sep 2024 17:27:19 -0700
Subject: [PATCH] models API

---
 llama_stack/apis/models/models.py           | 44 ++++++++++++--
 llama_stack/distribution/routers/routers.py |  3 -
 llama_stack/examples/router-table-run.yaml  | 67 +++++++++++++--------
 3 files changed, 81 insertions(+), 33 deletions(-)

diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index ee1d5f0ba..6917617bc 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -4,11 +4,47 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Protocol
+from typing import Any, Dict, List, Optional, Protocol
 
-from llama_models.schema_utils import webmethod  # noqa: F401
+from llama_models.llama3.api.datatypes import Model
 
-from pydantic import BaseModel  # noqa: F401
+from llama_models.schema_utils import json_schema_type, webmethod
+from llama_stack.distribution.datatypes import GenericProviderConfig
+from pydantic import BaseModel, Field
 
 
-class Models(Protocol): ...
+@json_schema_type
+class ModelServingSpec(BaseModel):
+    llama_model: Model = Field(
+        description="All metadata associated with the llama model (defined in llama_models.models.sku_list).",
+    )
+    provider_config: GenericProviderConfig = Field(
+        description="Provider config for the model, including provider_id and corresponding config.",
+    )
+    api: str = Field(
+        description="The API that this model is serving (e.g. inference / safety).",
+        default="inference",
+    )
+
+
+@json_schema_type
+class ModelsListResponse(BaseModel):
+    models_list: List[ModelServingSpec]
+
+
+@json_schema_type
+class ModelsGetResponse(BaseModel):
+    core_model_spec: Optional[ModelServingSpec] = None
+
+
+@json_schema_type
+class ModelsRegisterResponse(BaseModel):
+    core_model_spec: Optional[ModelServingSpec] = None
+
+
+class Models(Protocol):
+    @webmethod(route="/models/list", method="GET")
+    async def list_models(self) -> ModelsListResponse: ...
+
+    @webmethod(route="/models/get", method="POST")
+    async def get_model(self, core_model_id: str) -> ModelsGetResponse: ...
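Note: a minimal client sketch for the two routes added above, assuming a
llama-stack server is reachable over HTTP. The base URL is a placeholder, and
passing core_model_id as a JSON body is an assumption about the wire format,
not something this patch pins down.

    # Hypothetical client for the new Models routes (illustration only).
    # Assumes the server exposes the webmethods as plain HTTP endpoints.
    import httpx

    BASE_URL = "http://localhost:5000"  # placeholder, not defined by this patch

    def list_models() -> dict:
        # GET /models/list -> ModelsListResponse, e.g. {"models_list": [...]}
        resp = httpx.get(f"{BASE_URL}/models/list")
        resp.raise_for_status()
        return resp.json()

    def get_model(core_model_id: str) -> dict:
        # POST /models/get -> ModelsGetResponse, e.g. {"core_model_spec": {...}}
        resp = httpx.post(f"{BASE_URL}/models/get", json={"core_model_id": core_model_id})
        resp.raise_for_status()
        return resp.json()

    if __name__ == "__main__":
        print(list_models())
        print(get_model("Meta-Llama3.1-8B-Instruct"))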
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index eb9aaa540..fe70cd701 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -47,7 +47,6 @@ class MemoryRouter(Memory):
         config: MemoryBankConfig,
         url: Optional[URL] = None,
     ) -> MemoryBank:
-        print("MemoryRouter: create_memory_bank")
         bank_type = config.type
         bank = await self.routing_table.get_provider_impl(
             self.api, bank_type
@@ -56,7 +55,6 @@ class MemoryRouter(Memory):
         return bank
 
     async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        print("MemoryRouter: get_memory_bank")
         return await self.get_provider_from_bank_id(bank_id).get_memory_bank(bank_id)
 
     async def insert_documents(
@@ -65,7 +63,6 @@ class MemoryRouter(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        print("MemoryRouter: insert_documents")
         return await self.get_provider_from_bank_id(bank_id).insert_documents(
             bank_id, documents, ttl_seconds
         )
diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml
index df540674b..0f83cef06 100644
--- a/llama_stack/examples/router-table-run.yaml
+++ b/llama_stack/examples/router-table-run.yaml
@@ -3,40 +3,55 @@ image_name: local
 docker_image: null
 conda_env: local
 apis_to_serve:
-# - inference
+- inference
 - memory
 - telemetry
+- agents
+- safety
 provider_map:
   telemetry:
     provider_id: meta-reference
     config: {}
+  safety:
+    provider_id: meta-reference
+    config:
+      llama_guard_shield:
+        model: Llama-Guard-3-8B
+        excluded_categories: []
+        disable_input_check: false
+        disable_output_check: false
+      prompt_guard_shield:
+        model: Prompt-Guard-86M
+  agents:
+    provider_id: meta-reference
+    config: {}
 provider_routing_table:
-  # inference:
-  #   - routing_key: Meta-Llama3.1-8B-Instruct
-  #     provider_id: meta-reference
-  #     config:
-  #       model: Meta-Llama3.1-8B-Instruct
-  #       quantization: null
-  #       torch_seed: null
-  #       max_seq_len: 4096
-  #       max_batch_size: 1
-  #   - routing_key: Meta-Llama3.1-8B
-  #     provider_id: meta-reference
-  #     config:
-  #       model: Meta-Llama3.1-8B
-  #       quantization: null
-  #       torch_seed: null
-  #       max_seq_len: 4096
-  #       max_batch_size: 1
+  inference:
+    - routing_key: Meta-Llama3.1-8B-Instruct
+      provider_id: meta-reference
+      config:
+        model: Meta-Llama3.1-8B-Instruct
+        quantization: null
+        torch_seed: null
+        max_seq_len: 4096
+        max_batch_size: 1
+    # - routing_key: Meta-Llama3.1-8B
+    #   provider_id: meta-reference
+    #   config:
+    #     model: Meta-Llama3.1-8B
+    #     quantization: null
+    #     torch_seed: null
+    #     max_seq_len: 4096
+    #     max_batch_size: 1
   memory:
-    - routing_key: keyvalue
-      provider_id: remote::pgvector
-      config:
-        host: localhost
-        port: 5432
-        db: vectordb
-        user: vectoruser
-        password: xxxx
+    # - routing_key: keyvalue
+    #   provider_id: remote::pgvector
+    #   config:
+    #     host: localhost
+    #     port: 5432
+    #     db: vectordb
+    #     user: vectoruser
+    #     password: xxxx
     - routing_key: vector
       provider_id: meta-reference
      config: {}
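Note: a quick sanity check for the run config above, as a sketch only. It uses
PyYAML (an assumed dependency) and mirrors the key names in the example file;
the resolution rule (provider_map for single-provider APIs, provider_routing_table
for routed ones) reflects the structure of the YAML, not llama_stack's actual
validation code.

    # Illustrative config check, not part of the patch.
    import yaml

    with open("llama_stack/examples/router-table-run.yaml") as f:
        cfg = yaml.safe_load(f)

    for api in cfg["apis_to_serve"]:
        if api in cfg.get("provider_map", {}):
            # Single-provider APIs (telemetry, safety, agents) map directly.
            print(f"{api}: provider {cfg['provider_map'][api]['provider_id']}")
        elif api in cfg.get("provider_routing_table", {}):
            # Routed APIs (inference, memory) fan out by routing_key.
            keys = [e["routing_key"] for e in cfg["provider_routing_table"][api]]
            print(f"{api}: routed via {keys}")
        else:
            raise ValueError(f"{api} has no provider configured")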