diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 3d524780a..ce6226f98 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
- "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 16:18:00.903125"
+ "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559"
},
"servers": [
{
@@ -2291,75 +2291,6 @@
"required": true
}
}
- },
- "/memory_banks/update": {
- "post": {
- "responses": {},
- "tags": [
- "MemoryBanks"
- ],
- "parameters": [
- {
- "name": "X-LlamaStack-ProviderData",
- "in": "header",
- "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
- "required": false,
- "schema": {
- "type": "string"
- }
- }
- ],
- "requestBody": {
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/UpdateMemoryBankRequest"
- }
- }
- },
- "required": true
- }
- }
- },
- "/models/update": {
- "post": {
- "responses": {
- "200": {
- "description": "OK",
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/Model"
- }
- }
- }
- }
- },
- "tags": [
- "Models"
- ],
- "parameters": [
- {
- "name": "X-LlamaStack-ProviderData",
- "in": "header",
- "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
- "required": false,
- "schema": {
- "type": "string"
- }
- }
- ],
- "requestBody": {
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/UpdateModelRequest"
- }
- }
- },
- "required": true
- }
- }
}
},
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
@@ -7985,84 +7916,6 @@
"required": [
"model_id"
]
- },
- "UpdateMemoryBankRequest": {
- "type": "object",
- "properties": {
- "memory_bank_id": {
- "type": "string"
- },
- "params": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/VectorMemoryBankParams"
- },
- {
- "$ref": "#/components/schemas/KeyValueMemoryBankParams"
- },
- {
- "$ref": "#/components/schemas/KeywordMemoryBankParams"
- },
- {
- "$ref": "#/components/schemas/GraphMemoryBankParams"
- }
- ]
- },
- "provider_id": {
- "type": "string"
- },
- "provider_memory_bank_id": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "memory_bank_id",
- "params"
- ]
- },
- "UpdateModelRequest": {
- "type": "object",
- "properties": {
- "model_id": {
- "type": "string"
- },
- "provider_model_id": {
- "type": "string"
- },
- "provider_id": {
- "type": "string"
- },
- "metadata": {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "null"
- },
- {
- "type": "boolean"
- },
- {
- "type": "number"
- },
- {
- "type": "string"
- },
- {
- "type": "array"
- },
- {
- "type": "object"
- }
- ]
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "model_id"
- ]
}
},
"responses": {}
@@ -8679,14 +8532,6 @@
"name": "UnstructuredLogEvent",
"description": ""
},
- {
- "name": "UpdateMemoryBankRequest",
- "description": ""
- },
- {
- "name": "UpdateModelRequest",
- "description": ""
- },
{
"name": "UserMessage",
"description": ""
@@ -8873,8 +8718,6 @@
"UnregisterMemoryBankRequest",
"UnregisterModelRequest",
"UnstructuredLogEvent",
- "UpdateMemoryBankRequest",
- "UpdateModelRequest",
"UserMessage",
"VectorMemoryBank",
"VectorMemoryBankParams",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 8e3187861..a0b3d6c5e 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -3288,47 +3288,6 @@ components:
- message
- severity
type: object
- UpdateMemoryBankRequest:
- additionalProperties: false
- properties:
- memory_bank_id:
- type: string
- params:
- oneOf:
- - $ref: '#/components/schemas/VectorMemoryBankParams'
- - $ref: '#/components/schemas/KeyValueMemoryBankParams'
- - $ref: '#/components/schemas/KeywordMemoryBankParams'
- - $ref: '#/components/schemas/GraphMemoryBankParams'
- provider_id:
- type: string
- provider_memory_bank_id:
- type: string
- required:
- - memory_bank_id
- - params
- type: object
- UpdateModelRequest:
- additionalProperties: false
- properties:
- metadata:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- model_id:
- type: string
- provider_id:
- type: string
- provider_model_id:
- type: string
- required:
- - model_id
- type: object
UserMessage:
additionalProperties: false
properties:
@@ -3441,7 +3400,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
- \ draft and subject to change.\n Generated at 2024-11-14 16:18:00.903125"
+ \ draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -4264,25 +4223,6 @@ paths:
description: OK
tags:
- MemoryBanks
- /memory_banks/update:
- post:
- parameters:
- - description: JSON-encoded provider data which will be made available to the
- adapter servicing the API
- in: header
- name: X-LlamaStack-ProviderData
- required: false
- schema:
- type: string
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/UpdateMemoryBankRequest'
- required: true
- responses: {}
- tags:
- - MemoryBanks
/models/get:
get:
parameters:
@@ -4374,31 +4314,6 @@ paths:
description: OK
tags:
- Models
- /models/update:
- post:
- parameters:
- - description: JSON-encoded provider data which will be made available to the
- adapter servicing the API
- in: header
- name: X-LlamaStack-ProviderData
- required: false
- schema:
- type: string
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/UpdateModelRequest'
- required: true
- responses:
- '200':
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/Model'
- description: OK
- tags:
- - Models
/post_training/job/artifacts:
get:
parameters:
@@ -5330,12 +5245,6 @@ tags:
- description:
name: UnstructuredLogEvent
-- description:
- name: UpdateMemoryBankRequest
-- description:
- name: UpdateModelRequest
- description:
name: UserMessage
- description: MemoryBank: ...
- @webmethod(route="/memory_banks/update", method="POST")
- async def update_memory_bank(
- self,
- memory_bank_id: str,
- params: BankParams,
- provider_id: Optional[str] = None,
- provider_memory_bank_id: Optional[str] = None,
- ) -> MemoryBank: ...
-
@webmethod(route="/memory_banks/unregister", method="POST")
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py
index 3f4f683b3..34541b96e 100644
--- a/llama_stack/apis/models/client.py
+++ b/llama_stack/apis/models/client.py
@@ -7,7 +7,7 @@
import asyncio
import json
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
import fire
import httpx
@@ -61,27 +61,6 @@ class ModelsClient(Models):
return None
return Model(**j)
- async def update_model(
- self,
- model_id: str,
- provider_model_id: Optional[str] = None,
- provider_id: Optional[str] = None,
- metadata: Optional[Dict[str, Any]] = None,
- ) -> Model:
- async with httpx.AsyncClient() as client:
- response = await client.put(
- f"{self.base_url}/models/update",
- json={
- "model_id": model_id,
- "provider_model_id": provider_model_id,
- "provider_id": provider_id,
- "metadata": metadata,
- },
- headers={"Content-Type": "application/json"},
- )
- response.raise_for_status()
- return Model(**response.json())
-
async def unregister_model(self, model_id: str) -> None:
async with httpx.AsyncClient() as client:
response = await client.delete(
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index 19dae12a0..a1bfcac00 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -55,14 +55,5 @@ class Models(Protocol):
metadata: Optional[Dict[str, Any]] = None,
) -> Model: ...
- @webmethod(route="/models/update", method="POST")
- async def update_model(
- self,
- model_id: str,
- provider_model_id: Optional[str] = None,
- provider_id: Optional[str] = None,
- metadata: Optional[Dict[str, Any]] = None,
- ) -> Model: ...
-
@webmethod(route="/models/unregister", method="POST")
async def unregister_model(self, model_id: str) -> None: ...
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index d196d3557..76078e652 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -51,18 +51,6 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
raise ValueError(f"Unknown API {api} for registering object with provider")
-async def update_object_with_provider(
- obj: RoutableObject, p: Any
-) -> Optional[RoutableObject]:
- api = get_impl_api(p)
- if api == Api.memory:
- return await p.update_memory_bank(obj)
- elif api == Api.inference:
- return await p.update_model(obj)
- else:
- raise ValueError(f"Update not supported for {api}")
-
-
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.memory:
@@ -176,14 +164,6 @@ class CommonRoutingTableImpl(RoutingTable):
obj, self.impls_by_provider_id[obj.provider_id]
)
- async def update_object(
- self, obj: RoutableObjectWithProvider
- ) -> RoutableObjectWithProvider:
- registered_obj = await update_object_with_provider(
- obj, self.impls_by_provider_id[obj.provider_id]
- )
- return await self.dist_registry.update(registered_obj or obj)
-
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
@@ -256,27 +236,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
registered_model = await self.register_object(model)
return registered_model
- async def update_model(
- self,
- model_id: str,
- provider_model_id: Optional[str] = None,
- provider_id: Optional[str] = None,
- metadata: Optional[Dict[str, Any]] = None,
- ) -> Model:
- existing_model = await self.get_model(model_id)
- if existing_model is None:
- raise ValueError(f"Model {model_id} not found")
-
- updated_model = Model(
- identifier=model_id,
- provider_resource_id=provider_model_id
- or existing_model.provider_resource_id,
- provider_id=provider_id or existing_model.provider_id,
- metadata=metadata or existing_model.metadata,
- )
- registered_model = await self.update_object(updated_model)
- return registered_model
-
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
@@ -357,31 +316,6 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
await self.register_object(memory_bank)
return memory_bank
- async def update_memory_bank(
- self,
- memory_bank_id: str,
- params: BankParams,
- provider_id: Optional[str] = None,
- provider_memory_bank_id: Optional[str] = None,
- ) -> MemoryBank:
- existing_bank = await self.get_memory_bank(memory_bank_id)
- if existing_bank is None:
- raise ValueError(f"Memory bank {memory_bank_id} not found")
-
- updated_bank = parse_obj_as(
- MemoryBank,
- {
- "identifier": memory_bank_id,
- "type": ResourceType.memory_bank.value,
- "provider_id": provider_id or existing_bank.provider_id,
- "provider_resource_id": provider_memory_bank_id
- or existing_bank.provider_resource_id,
- **params.model_dump(),
- },
- )
- registered_bank = await self.update_object(updated_bank)
- return registered_bank
-
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
existing_bank = await self.get_memory_bank(memory_bank_id)
if existing_bank is None:
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 1b5d0ebd1..080204e45 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -45,8 +45,6 @@ class Api(Enum):
class ModelsProtocolPrivate(Protocol):
async def register_model(self, model: Model) -> None: ...
- async def update_model(self, model: Model) -> None: ...
-
async def unregister_model(self, model_id: str) -> None: ...
@@ -61,8 +59,6 @@ class MemoryBanksProtocolPrivate(Protocol):
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
- async def update_memory_bank(self, memory_bank: MemoryBank) -> None: ...
-
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...
@@ -107,6 +103,7 @@ class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ...
+# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
@@ -179,12 +176,10 @@ class RemoteProviderConfig(BaseModel):
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
- adapter: Optional[AdapterSpec] = Field(
- default=None,
+ adapter: AdapterSpec = Field(
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here. If not specified, it indicates the remote
-as being "Llama Stack compatible"
+API responses, specify the adapter here.
""",
)
@@ -194,38 +189,21 @@ as being "Llama Stack compatible"
@property
def module(self) -> str:
- if self.adapter:
- return self.adapter.module
- return "llama_stack.distribution.client"
+ return self.adapter.module
@property
def pip_packages(self) -> List[str]:
- if self.adapter:
- return self.adapter.pip_packages
- return []
+ return self.adapter.pip_packages
@property
def provider_data_validator(self) -> Optional[str]:
- if self.adapter:
- return self.adapter.provider_data_validator
- return None
+ return self.adapter.provider_data_validator
-def is_passthrough(spec: ProviderSpec) -> bool:
- return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
-
-
-# Can avoid this by using Pydantic computed_field
-def remote_provider_spec(
- api: Api, adapter: Optional[AdapterSpec] = None
-) -> RemoteProviderSpec:
- config_class = (
- adapter.config_class
- if adapter and adapter.config_class
- else "llama_stack.distribution.datatypes.RemoteProviderConfig"
- )
- provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
-
+def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
return RemoteProviderSpec(
- api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
+ api=api,
+ provider_type=f"remote::{adapter.adapter_type}",
+ config_class=adapter.config_class,
+ adapter=adapter,
)
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 18128f354..e6bcd6730 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -71,9 +71,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
- async def update_model(self, model: Model) -> None:
- pass
-
async def unregister_model(self, model_id: str) -> None:
pass
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index e4742240d..0e7ba872c 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -108,9 +108,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
return VLLMSamplingParams(**kwargs)
- async def update_model(self, model: Model) -> None:
- pass
-
async def unregister_model(self, model_id: str) -> None:
pass
diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py
index 9813f25ce..92235ea89 100644
--- a/llama_stack/providers/inline/memory/faiss/faiss.py
+++ b/llama_stack/providers/inline/memory/faiss/faiss.py
@@ -48,10 +48,9 @@ class FaissIndex(EmbeddingIndex):
self.initialize()
async def initialize(self) -> None:
- if not self.kvstore or not self.bank_id:
+ if not self.kvstore:
return
- # Load existing index data from kvstore
index_key = f"faiss_index:v1::{self.bank_id}"
stored_data = await self.kvstore.get(index_key)
@@ -63,7 +62,6 @@ class FaissIndex(EmbeddingIndex):
for k, v in data["chunk_by_index"].items()
}
- # Load FAISS index
index_bytes = base64.b64decode(data["faiss_index"])
self.index = faiss.deserialize_index(index_bytes)
@@ -71,17 +69,14 @@ class FaissIndex(EmbeddingIndex):
if not self.kvstore or not self.bank_id:
return
- # Serialize FAISS index
index_bytes = faiss.serialize_index(self.index)
- # Prepare data for storage
data = {
"id_by_index": self.id_by_index,
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
"faiss_index": base64.b64encode(index_bytes).decode(),
}
- # Store in kvstore
index_key = f"faiss_index:v1::{self.bank_id}"
await self.kvstore.set(key=index_key, value=json.dumps(data))
@@ -175,15 +170,6 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
del self.cache[memory_bank_id]
await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
- async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
- # Not possible to update the index in place, so we delete and recreate
- await self.cache[memory_bank.identifier].index.delete()
-
- self.cache[memory_bank.identifier] = BankWithIndex(
- bank=memory_bank,
- index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
- )
-
async def insert_documents(
self,
bank_id: str,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 208e5036b..3b3f3868b 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -93,9 +93,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
- async def update_model(self, model: Model) -> None:
- pass
-
async def unregister_model(self, model_id: str) -> None:
pass
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 0aef9d706..30745cb10 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -69,9 +69,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
- async def update_model(self, model: Model) -> None:
- pass
-
async def unregister_model(self, model_id: str) -> None:
pass
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 4dc8220f1..788f6cac4 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -58,9 +58,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
- async def update_model(self, model: Model) -> None:
- pass
-
async def unregister_model(self, model_id: str) -> None:
pass
diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py
index b74945f03..ac00fc749 100644
--- a/llama_stack/providers/remote/memory/chroma/chroma.py
+++ b/llama_stack/providers/remote/memory/chroma/chroma.py
@@ -141,10 +141,6 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
- async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
- await self.unregister_memory_bank(memory_bank.identifier)
- await self.register_memory_bank(memory_bank)
-
async def insert_documents(
self,
bank_id: str,
diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py
index 8395e77b0..44c2a8fe1 100644
--- a/llama_stack/providers/remote/memory/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py
@@ -184,10 +184,6 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
- async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
- await self.unregister_memory_bank(memory_bank.identifier)
- await self.register_memory_bank(memory_bank)
-
async def list_memory_banks(self) -> List[MemoryBank]:
banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks: