From aa93eeb2b7d33049c9526e3bbbf3efe70cb35505 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 14 Nov 2024 17:05:09 -0800 Subject: [PATCH] nuke updates --- docs/resources/llama-stack-spec.html | 159 +----------------- docs/resources/llama-stack-spec.yaml | 95 +---------- llama_stack/apis/memory_banks/memory_banks.py | 9 - llama_stack/apis/models/client.py | 23 +-- llama_stack/apis/models/models.py | 9 - .../distribution/routers/routing_tables.py | 66 -------- llama_stack/providers/datatypes.py | 44 ++--- .../inference/meta_reference/inference.py | 3 - .../providers/inline/inference/vllm/vllm.py | 3 - .../providers/inline/memory/faiss/faiss.py | 16 +- .../remote/inference/ollama/ollama.py | 3 - .../providers/remote/inference/tgi/tgi.py | 3 - .../providers/remote/inference/vllm/vllm.py | 3 - .../providers/remote/memory/chroma/chroma.py | 4 - .../remote/memory/pgvector/pgvector.py | 4 - 15 files changed, 15 insertions(+), 429 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 3d524780a..ce6226f98 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 16:18:00.903125" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559" }, "servers": [ { @@ -2291,75 +2291,6 @@ "required": true } } - }, - "/memory_banks/update": { - "post": { - "responses": {}, - "tags": [ - "MemoryBanks" - ], - "parameters": [ - { - "name": "X-LlamaStack-ProviderData", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/UpdateMemoryBankRequest" - } - } - }, - "required": true - } - } - }, - "/models/update": { - "post": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/Model" - } - } - } - } - }, - "tags": [ - "Models" - ], - "parameters": [ - { - "name": "X-LlamaStack-ProviderData", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/UpdateModelRequest" - } - } - }, - "required": true - } - } } }, "jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema", @@ -7985,84 +7916,6 @@ "required": [ "model_id" ] - }, - "UpdateMemoryBankRequest": { - "type": "object", - "properties": { - "memory_bank_id": { - "type": "string" - }, - "params": { - "oneOf": [ - { - "$ref": "#/components/schemas/VectorMemoryBankParams" - }, - { - "$ref": "#/components/schemas/KeyValueMemoryBankParams" - }, - { - "$ref": "#/components/schemas/KeywordMemoryBankParams" - }, - { - "$ref": 
"#/components/schemas/GraphMemoryBankParams" - } - ] - }, - "provider_id": { - "type": "string" - }, - "provider_memory_bank_id": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_id", - "params" - ] - }, - "UpdateModelRequest": { - "type": "object", - "properties": { - "model_id": { - "type": "string" - }, - "provider_model_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "metadata": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "model_id" - ] } }, "responses": {} @@ -8679,14 +8532,6 @@ "name": "UnstructuredLogEvent", "description": "" }, - { - "name": "UpdateMemoryBankRequest", - "description": "" - }, - { - "name": "UpdateModelRequest", - "description": "" - }, { "name": "UserMessage", "description": "" @@ -8873,8 +8718,6 @@ "UnregisterMemoryBankRequest", "UnregisterModelRequest", "UnstructuredLogEvent", - "UpdateMemoryBankRequest", - "UpdateModelRequest", "UserMessage", "VectorMemoryBank", "VectorMemoryBankParams", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 8e3187861..a0b3d6c5e 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -3288,47 +3288,6 @@ components: - message - severity type: object - UpdateMemoryBankRequest: - additionalProperties: false - properties: - memory_bank_id: - type: string - params: - oneOf: - - $ref: '#/components/schemas/VectorMemoryBankParams' - - $ref: '#/components/schemas/KeyValueMemoryBankParams' - - $ref: '#/components/schemas/KeywordMemoryBankParams' - - $ref: '#/components/schemas/GraphMemoryBankParams' - provider_id: - type: string - provider_memory_bank_id: - type: string - required: - - memory_bank_id - - params - type: object - UpdateModelRequest: - additionalProperties: false - properties: - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - model_id: - type: string - provider_id: - type: string - provider_model_id: - type: string - required: - - model_id - type: object UserMessage: additionalProperties: false properties: @@ -3441,7 +3400,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. 
The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-14 16:18:00.903125" + \ draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4264,25 +4223,6 @@ paths: description: OK tags: - MemoryBanks - /memory_banks/update: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-ProviderData - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/UpdateMemoryBankRequest' - required: true - responses: {} - tags: - - MemoryBanks /models/get: get: parameters: @@ -4374,31 +4314,6 @@ paths: description: OK tags: - Models - /models/update: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-ProviderData - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/UpdateModelRequest' - required: true - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/Model' - description: OK - tags: - - Models /post_training/job/artifacts: get: parameters: @@ -5330,12 +5245,6 @@ tags: - description: name: UnstructuredLogEvent -- description: - name: UpdateMemoryBankRequest -- description: - name: UpdateModelRequest - description: name: UserMessage - description: MemoryBank: ... - @webmethod(route="/memory_banks/update", method="POST") - async def update_memory_bank( - self, - memory_bank_id: str, - params: BankParams, - provider_id: Optional[str] = None, - provider_memory_bank_id: Optional[str] = None, - ) -> MemoryBank: ... - @webmethod(route="/memory_banks/unregister", method="POST") async def unregister_memory_bank(self, memory_bank_id: str) -> None: ... diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py index 3f4f683b3..34541b96e 100644 --- a/llama_stack/apis/models/client.py +++ b/llama_stack/apis/models/client.py @@ -7,7 +7,7 @@ import asyncio import json -from typing import Any, Dict, List, Optional +from typing import List, Optional import fire import httpx @@ -61,27 +61,6 @@ class ModelsClient(Models): return None return Model(**j) - async def update_model( - self, - model_id: str, - provider_model_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> Model: - async with httpx.AsyncClient() as client: - response = await client.put( - f"{self.base_url}/models/update", - json={ - "model_id": model_id, - "provider_model_id": provider_model_id, - "provider_id": provider_id, - "metadata": metadata, - }, - headers={"Content-Type": "application/json"}, - ) - response.raise_for_status() - return Model(**response.json()) - async def unregister_model(self, model_id: str) -> None: async with httpx.AsyncClient() as client: response = await client.delete( diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 19dae12a0..a1bfcac00 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -55,14 +55,5 @@ class Models(Protocol): metadata: Optional[Dict[str, Any]] = None, ) -> Model: ... 
- @webmethod(route="/models/update", method="POST") - async def update_model( - self, - model_id: str, - provider_model_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> Model: ... - @webmethod(route="/models/unregister", method="POST") async def unregister_model(self, model_id: str) -> None: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index d196d3557..76078e652 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -51,18 +51,6 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable raise ValueError(f"Unknown API {api} for registering object with provider") -async def update_object_with_provider( - obj: RoutableObject, p: Any -) -> Optional[RoutableObject]: - api = get_impl_api(p) - if api == Api.memory: - return await p.update_memory_bank(obj) - elif api == Api.inference: - return await p.update_model(obj) - else: - raise ValueError(f"Update not supported for {api}") - - async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: api = get_impl_api(p) if api == Api.memory: @@ -176,14 +164,6 @@ class CommonRoutingTableImpl(RoutingTable): obj, self.impls_by_provider_id[obj.provider_id] ) - async def update_object( - self, obj: RoutableObjectWithProvider - ) -> RoutableObjectWithProvider: - registered_obj = await update_object_with_provider( - obj, self.impls_by_provider_id[obj.provider_id] - ) - return await self.dist_registry.update(registered_obj or obj) - async def register_object( self, obj: RoutableObjectWithProvider ) -> RoutableObjectWithProvider: @@ -256,27 +236,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): registered_model = await self.register_object(model) return registered_model - async def update_model( - self, - model_id: str, - provider_model_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> Model: - existing_model = await self.get_model(model_id) - if existing_model is None: - raise ValueError(f"Model {model_id} not found") - - updated_model = Model( - identifier=model_id, - provider_resource_id=provider_model_id - or existing_model.provider_resource_id, - provider_id=provider_id or existing_model.provider_id, - metadata=metadata or existing_model.metadata, - ) - registered_model = await self.update_object(updated_model) - return registered_model - async def unregister_model(self, model_id: str) -> None: existing_model = await self.get_model(model_id) if existing_model is None: @@ -357,31 +316,6 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): await self.register_object(memory_bank) return memory_bank - async def update_memory_bank( - self, - memory_bank_id: str, - params: BankParams, - provider_id: Optional[str] = None, - provider_memory_bank_id: Optional[str] = None, - ) -> MemoryBank: - existing_bank = await self.get_memory_bank(memory_bank_id) - if existing_bank is None: - raise ValueError(f"Memory bank {memory_bank_id} not found") - - updated_bank = parse_obj_as( - MemoryBank, - { - "identifier": memory_bank_id, - "type": ResourceType.memory_bank.value, - "provider_id": provider_id or existing_bank.provider_id, - "provider_resource_id": provider_memory_bank_id - or existing_bank.provider_resource_id, - **params.model_dump(), - }, - ) - registered_bank = await self.update_object(updated_bank) - return 
registered_bank - async def unregister_memory_bank(self, memory_bank_id: str) -> None: existing_bank = await self.get_memory_bank(memory_bank_id) if existing_bank is None: diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 1b5d0ebd1..080204e45 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -45,8 +45,6 @@ class Api(Enum): class ModelsProtocolPrivate(Protocol): async def register_model(self, model: Model) -> None: ... - async def update_model(self, model: Model) -> None: ... - async def unregister_model(self, model_id: str) -> None: ... @@ -61,8 +59,6 @@ class MemoryBanksProtocolPrivate(Protocol): async def unregister_memory_bank(self, memory_bank_id: str) -> None: ... - async def update_memory_bank(self, memory_bank: MemoryBank) -> None: ... - class DatasetsProtocolPrivate(Protocol): async def register_dataset(self, dataset: Dataset) -> None: ... @@ -107,6 +103,7 @@ class RoutingTable(Protocol): def get_provider_impl(self, routing_key: str) -> Any: ... +# TODO: this can now be inlined into RemoteProviderSpec @json_schema_type class AdapterSpec(BaseModel): adapter_type: str = Field( @@ -179,12 +176,10 @@ class RemoteProviderConfig(BaseModel): @json_schema_type class RemoteProviderSpec(ProviderSpec): - adapter: Optional[AdapterSpec] = Field( - default=None, + adapter: AdapterSpec = Field( description=""" If some code is needed to convert the remote responses into Llama Stack compatible -API responses, specify the adapter here. If not specified, it indicates the remote -as being "Llama Stack compatible" +API responses, specify the adapter here. """, ) @@ -194,38 +189,21 @@ as being "Llama Stack compatible" @property def module(self) -> str: - if self.adapter: - return self.adapter.module - return "llama_stack.distribution.client" + return self.adapter.module @property def pip_packages(self) -> List[str]: - if self.adapter: - return self.adapter.pip_packages - return [] + return self.adapter.pip_packages @property def provider_data_validator(self) -> Optional[str]: - if self.adapter: - return self.adapter.provider_data_validator - return None + return self.adapter.provider_data_validator -def is_passthrough(spec: ProviderSpec) -> bool: - return isinstance(spec, RemoteProviderSpec) and spec.adapter is None - - -# Can avoid this by using Pydantic computed_field -def remote_provider_spec( - api: Api, adapter: Optional[AdapterSpec] = None -) -> RemoteProviderSpec: - config_class = ( - adapter.config_class - if adapter and adapter.config_class - else "llama_stack.distribution.datatypes.RemoteProviderConfig" - ) - provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote" - +def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec: return RemoteProviderSpec( - api=api, provider_type=provider_type, config_class=config_class, adapter=adapter + api=api, + provider_type=f"remote::{adapter.adapter_type}", + config_class=adapter.config_class, + adapter=adapter, ) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 18128f354..e6bcd6730 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -71,9 +71,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP f"Model mismatch: {request.model} != {self.model.descriptor()}" ) - async def update_model(self, model: 
Model) -> None: - pass - async def unregister_model(self, model_id: str) -> None: pass diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index e4742240d..0e7ba872c 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -108,9 +108,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): return VLLMSamplingParams(**kwargs) - async def update_model(self, model: Model) -> None: - pass - async def unregister_model(self, model_id: str) -> None: pass diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index 9813f25ce..92235ea89 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -48,10 +48,9 @@ class FaissIndex(EmbeddingIndex): self.initialize() async def initialize(self) -> None: - if not self.kvstore or not self.bank_id: + if not self.kvstore: return - # Load existing index data from kvstore index_key = f"faiss_index:v1::{self.bank_id}" stored_data = await self.kvstore.get(index_key) @@ -63,7 +62,6 @@ class FaissIndex(EmbeddingIndex): for k, v in data["chunk_by_index"].items() } - # Load FAISS index index_bytes = base64.b64decode(data["faiss_index"]) self.index = faiss.deserialize_index(index_bytes) @@ -71,17 +69,14 @@ class FaissIndex(EmbeddingIndex): if not self.kvstore or not self.bank_id: return - # Serialize FAISS index index_bytes = faiss.serialize_index(self.index) - # Prepare data for storage data = { "id_by_index": self.id_by_index, "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()}, "faiss_index": base64.b64encode(index_bytes).decode(), } - # Store in kvstore index_key = f"faiss_index:v1::{self.bank_id}" await self.kvstore.set(key=index_key, value=json.dumps(data)) @@ -175,15 +170,6 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): del self.cache[memory_bank_id] await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}") - async def update_memory_bank(self, memory_bank: MemoryBank) -> None: - # Not possible to update the index in place, so we delete and recreate - await self.cache[memory_bank.identifier].index.delete() - - self.cache[memory_bank.identifier] = BankWithIndex( - bank=memory_bank, - index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore), - ) - async def insert_documents( self, bank_id: str, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 208e5036b..3b3f3868b 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -93,9 +93,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def shutdown(self) -> None: pass - async def update_model(self, model: Model) -> None: - pass - async def unregister_model(self, model_id: str) -> None: pass diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 0aef9d706..30745cb10 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -69,9 +69,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): async def shutdown(self) -> None: pass - async def update_model(self, model: Model) -> None: - pass - async def unregister_model(self, model_id: str) -> None: pass diff --git 
a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 4dc8220f1..788f6cac4 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -58,9 +58,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): async def shutdown(self) -> None: pass - async def update_model(self, model: Model) -> None: - pass - async def unregister_model(self, model_id: str) -> None: pass diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index b74945f03..ac00fc749 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -141,10 +141,6 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): await self.cache[memory_bank_id].index.delete() del self.cache[memory_bank_id] - async def update_memory_bank(self, memory_bank: MemoryBank) -> None: - await self.unregister_memory_bank(memory_bank.identifier) - await self.register_memory_bank(memory_bank) - async def insert_documents( self, bank_id: str, diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index 8395e77b0..44c2a8fe1 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -184,10 +184,6 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): await self.cache[memory_bank_id].index.delete() del self.cache[memory_bank_id] - async def update_memory_bank(self, memory_bank: MemoryBank) -> None: - await self.unregister_memory_bank(memory_bank.identifier) - await self.register_memory_bank(memory_bank) - async def list_memory_banks(self) -> List[MemoryBank]: banks = load_models(self.cursor, VectorMemoryBank) for bank in banks:
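
Migration note: with /models/update and /memory_banks/update removed, the supported way to change a registration is to unregister it and register it again, which is the same flow the deleted Chroma and PGVector update_memory_bank helpers performed internally. Below is a minimal sketch against the Models protocol shown in this patch; the function name, identifier, provider, and metadata values are illustrative, not taken from the tree:

from typing import Any, Dict

from llama_stack.apis.models import Models


async def replace_registration(
    models: Models,
    model_id: str,
    provider_id: str,
    metadata: Dict[str, Any],
) -> None:
    # The removed update endpoint mutated a registration in place; with it
    # gone, drop the old entry first...
    await models.unregister_model(model_id)
    # ...then register a fresh one under the same identifier.
    await models.register_model(
        model_id=model_id,
        provider_id=provider_id,
        metadata=metadata,
    )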
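
On the provider side, ModelsProtocolPrivate now carries only register_model and unregister_model. A sketch of the slimmed surface, assuming a no-op in-memory provider (the class name and dict storage are illustrative, not from the tree):

from typing import Dict

from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import ModelsProtocolPrivate


class InMemoryModels(ModelsProtocolPrivate):
    def __init__(self) -> None:
        self._models: Dict[str, Model] = {}

    async def register_model(self, model: Model) -> None:
        # With update_model gone from the protocol, re-registering the same
        # identifier simply overwrites the stored entry.
        self._models[model.identifier] = model

    async def unregister_model(self, model_id: str) -> None:
        self._models.pop(model_id, None)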
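
RemoteProviderSpec.adapter is likewise no longer optional, so the passthrough case (a remote that is itself "Llama Stack compatible") can no longer be expressed, and every remote registry entry must supply an AdapterSpec. A hypothetical entry under that constraint; the adapter_type, module, pip packages, and config_class below are made-up names for illustration:

from llama_stack.providers.datatypes import AdapterSpec, Api, remote_provider_spec

spec = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_type="my_backend",
        pip_packages=["httpx"],
        module="llama_stack.providers.remote.inference.my_backend",
        config_class="llama_stack.providers.remote.inference.my_backend.MyBackendConfig",
    ),
)

# provider_type derives from the adapter, and module/pip_packages now resolve
# through it unconditionally; the "llama_stack.distribution.client" fallback
# and is_passthrough() are gone.
assert spec.provider_type == "remote::my_backend"
assert spec.module == "llama_stack.providers.remote.inference.my_backend"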