forked from phoenix-oss/llama-stack-mirror
unregister for memory banks and remove update API (#458)
The semantics of an update on resources are very tricky to reason about, especially for memory banks and models. The best way forward is for the user to unregister the existing resource and register a new one. We don't have a compelling reason to support update APIs.

Tests:

pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "chroma" --env CHROMA_HOST=localhost --env CHROMA_PORT=8000

pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "pgvector" --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0

$CONDA_PREFIX/bin/pytest -v -s -m "ollama" llama_stack/providers/tests/inference/test_model_registration.py

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
2eab3b7ed9
commit
0850ad656a
18 changed files with 286 additions and 250 deletions
|
@ -21,7 +21,7 @@
|
||||||
"info": {
|
"info": {
|
||||||
"title": "[DRAFT] Llama Stack Specification",
|
"title": "[DRAFT] Llama Stack Specification",
|
||||||
"version": "0.0.1",
|
"version": "0.0.1",
|
||||||
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 12:51:12.176325"
|
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559"
|
||||||
},
|
},
|
||||||
"servers": [
|
"servers": [
|
||||||
{
|
{
|
||||||
|
@ -429,39 +429,6 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/models/delete": {
|
|
||||||
"post": {
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "OK"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"tags": [
|
|
||||||
"Models"
|
|
||||||
],
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "X-LlamaStack-ProviderData",
|
|
||||||
"in": "header",
|
|
||||||
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
|
||||||
"required": false,
|
|
||||||
"schema": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"requestBody": {
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/DeleteModelRequest"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/inference/embeddings": {
|
"/inference/embeddings": {
|
||||||
"post": {
|
"post": {
|
||||||
"responses": {
|
"responses": {
|
||||||
|
@ -2259,20 +2226,46 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/models/update": {
|
"/memory_banks/unregister": {
|
||||||
"post": {
|
"post": {
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
"description": "OK",
|
"description": "OK"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"MemoryBanks"
|
||||||
|
],
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "X-LlamaStack-ProviderData",
|
||||||
|
"in": "header",
|
||||||
|
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
||||||
|
"required": false,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"requestBody": {
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": {
|
"schema": {
|
||||||
"$ref": "#/components/schemas/Model"
|
"$ref": "#/components/schemas/UnregisterMemoryBankRequest"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/models/unregister": {
|
||||||
|
"post": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK"
|
||||||
|
}
|
||||||
|
},
|
||||||
"tags": [
|
"tags": [
|
||||||
"Models"
|
"Models"
|
||||||
],
|
],
|
||||||
|
@ -2291,7 +2284,7 @@
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": {
|
"schema": {
|
||||||
"$ref": "#/components/schemas/UpdateModelRequest"
|
"$ref": "#/components/schemas/UnregisterModelRequest"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -4622,18 +4615,6 @@
|
||||||
"session_id"
|
"session_id"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"DeleteModelRequest": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"model_id": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": [
|
|
||||||
"model_id"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"EmbeddingsRequest": {
|
"EmbeddingsRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -7912,42 +7893,23 @@
|
||||||
],
|
],
|
||||||
"title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
|
"title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
|
||||||
},
|
},
|
||||||
"UpdateModelRequest": {
|
"UnregisterMemoryBankRequest": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"memory_bank_id": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"memory_bank_id"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"UnregisterModelRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"model_id": {
|
"model_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
|
||||||
"provider_model_id": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"provider_id": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"metadata": {
|
|
||||||
"type": "object",
|
|
||||||
"additionalProperties": {
|
|
||||||
"oneOf": [
|
|
||||||
{
|
|
||||||
"type": "null"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "boolean"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "number"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "array"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "object"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -8132,10 +8094,6 @@
|
||||||
"name": "DeleteAgentsSessionRequest",
|
"name": "DeleteAgentsSessionRequest",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/DeleteAgentsSessionRequest\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/DeleteAgentsSessionRequest\" />"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "DeleteModelRequest",
|
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/DeleteModelRequest\" />"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "DoraFinetuningConfig",
|
"name": "DoraFinetuningConfig",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/DoraFinetuningConfig\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/DoraFinetuningConfig\" />"
|
||||||
|
@ -8563,12 +8521,16 @@
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/URL\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/URL\" />"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "UnstructuredLogEvent",
|
"name": "UnregisterMemoryBankRequest",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterMemoryBankRequest\" />"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "UpdateModelRequest",
|
"name": "UnregisterModelRequest",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UpdateModelRequest\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnregisterModelRequest\" />"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "UnstructuredLogEvent",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "UserMessage",
|
"name": "UserMessage",
|
||||||
|
@ -8657,7 +8619,6 @@
|
||||||
"Dataset",
|
"Dataset",
|
||||||
"DeleteAgentsRequest",
|
"DeleteAgentsRequest",
|
||||||
"DeleteAgentsSessionRequest",
|
"DeleteAgentsSessionRequest",
|
||||||
"DeleteModelRequest",
|
|
||||||
"DoraFinetuningConfig",
|
"DoraFinetuningConfig",
|
||||||
"EmbeddingsRequest",
|
"EmbeddingsRequest",
|
||||||
"EmbeddingsResponse",
|
"EmbeddingsResponse",
|
||||||
|
@ -8754,8 +8715,9 @@
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"Turn",
|
"Turn",
|
||||||
"URL",
|
"URL",
|
||||||
|
"UnregisterMemoryBankRequest",
|
||||||
|
"UnregisterModelRequest",
|
||||||
"UnstructuredLogEvent",
|
"UnstructuredLogEvent",
|
||||||
"UpdateModelRequest",
|
|
||||||
"UserMessage",
|
"UserMessage",
|
||||||
"VectorMemoryBank",
|
"VectorMemoryBank",
|
||||||
"VectorMemoryBankParams",
|
"VectorMemoryBankParams",
|
||||||
|
|
|
@ -867,14 +867,6 @@ components:
|
||||||
- agent_id
|
- agent_id
|
||||||
- session_id
|
- session_id
|
||||||
type: object
|
type: object
|
||||||
DeleteModelRequest:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
model_id:
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- model_id
|
|
||||||
type: object
|
|
||||||
DoraFinetuningConfig:
|
DoraFinetuningConfig:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -3244,6 +3236,22 @@ components:
|
||||||
format: uri
|
format: uri
|
||||||
pattern: ^(https?://|file://|data:)
|
pattern: ^(https?://|file://|data:)
|
||||||
type: string
|
type: string
|
||||||
|
UnregisterMemoryBankRequest:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
memory_bank_id:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- memory_bank_id
|
||||||
|
type: object
|
||||||
|
UnregisterModelRequest:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
model_id:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- model_id
|
||||||
|
type: object
|
||||||
UnstructuredLogEvent:
|
UnstructuredLogEvent:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -3280,28 +3288,6 @@ components:
|
||||||
- message
|
- message
|
||||||
- severity
|
- severity
|
||||||
type: object
|
type: object
|
||||||
UpdateModelRequest:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
metadata:
|
|
||||||
additionalProperties:
|
|
||||||
oneOf:
|
|
||||||
- type: 'null'
|
|
||||||
- type: boolean
|
|
||||||
- type: number
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
- type: object
|
|
||||||
type: object
|
|
||||||
model_id:
|
|
||||||
type: string
|
|
||||||
provider_id:
|
|
||||||
type: string
|
|
||||||
provider_model_id:
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- model_id
|
|
||||||
type: object
|
|
||||||
UserMessage:
|
UserMessage:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -3414,7 +3400,7 @@ info:
|
||||||
description: "This is the specification of the llama stack that provides\n \
|
description: "This is the specification of the llama stack that provides\n \
|
||||||
\ a set of endpoints and their corresponding interfaces that are tailored\
|
\ a set of endpoints and their corresponding interfaces that are tailored\
|
||||||
\ to\n best leverage Llama Models. The specification is still in\
|
\ to\n best leverage Llama Models. The specification is still in\
|
||||||
\ draft and subject to change.\n Generated at 2024-11-14 12:51:12.176325"
|
\ draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559"
|
||||||
title: '[DRAFT] Llama Stack Specification'
|
title: '[DRAFT] Llama Stack Specification'
|
||||||
version: 0.0.1
|
version: 0.0.1
|
||||||
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
||||||
|
@ -4216,7 +4202,7 @@ paths:
|
||||||
responses: {}
|
responses: {}
|
||||||
tags:
|
tags:
|
||||||
- MemoryBanks
|
- MemoryBanks
|
||||||
/models/delete:
|
/memory_banks/unregister:
|
||||||
post:
|
post:
|
||||||
parameters:
|
parameters:
|
||||||
- description: JSON-encoded provider data which will be made available to the
|
- description: JSON-encoded provider data which will be made available to the
|
||||||
|
@ -4230,13 +4216,13 @@ paths:
|
||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/DeleteModelRequest'
|
$ref: '#/components/schemas/UnregisterMemoryBankRequest'
|
||||||
required: true
|
required: true
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Models
|
- MemoryBanks
|
||||||
/models/get:
|
/models/get:
|
||||||
get:
|
get:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -4307,7 +4293,7 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Models
|
- Models
|
||||||
/models/update:
|
/models/unregister:
|
||||||
post:
|
post:
|
||||||
parameters:
|
parameters:
|
||||||
- description: JSON-encoded provider data which will be made available to the
|
- description: JSON-encoded provider data which will be made available to the
|
||||||
|
@ -4321,14 +4307,10 @@ paths:
|
||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/UpdateModelRequest'
|
$ref: '#/components/schemas/UnregisterModelRequest'
|
||||||
required: true
|
required: true
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Model'
|
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Models
|
- Models
|
||||||
|
@ -4960,9 +4942,6 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteAgentsSessionRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteAgentsSessionRequest"
|
||||||
/>
|
/>
|
||||||
name: DeleteAgentsSessionRequest
|
name: DeleteAgentsSessionRequest
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteModelRequest"
|
|
||||||
/>
|
|
||||||
name: DeleteModelRequest
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/DoraFinetuningConfig"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/DoraFinetuningConfig"
|
||||||
/>
|
/>
|
||||||
name: DoraFinetuningConfig
|
name: DoraFinetuningConfig
|
||||||
|
@ -5257,12 +5236,15 @@ tags:
|
||||||
name: Turn
|
name: Turn
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/URL" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/URL" />
|
||||||
name: URL
|
name: URL
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterMemoryBankRequest"
|
||||||
|
/>
|
||||||
|
name: UnregisterMemoryBankRequest
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterModelRequest"
|
||||||
|
/>
|
||||||
|
name: UnregisterModelRequest
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
|
||||||
/>
|
/>
|
||||||
name: UnstructuredLogEvent
|
name: UnstructuredLogEvent
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UpdateModelRequest"
|
|
||||||
/>
|
|
||||||
name: UpdateModelRequest
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UserMessage" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UserMessage" />
|
||||||
name: UserMessage
|
name: UserMessage
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/VectorMemoryBank"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/VectorMemoryBank"
|
||||||
|
@ -5338,7 +5320,6 @@ x-tagGroups:
|
||||||
- Dataset
|
- Dataset
|
||||||
- DeleteAgentsRequest
|
- DeleteAgentsRequest
|
||||||
- DeleteAgentsSessionRequest
|
- DeleteAgentsSessionRequest
|
||||||
- DeleteModelRequest
|
|
||||||
- DoraFinetuningConfig
|
- DoraFinetuningConfig
|
||||||
- EmbeddingsRequest
|
- EmbeddingsRequest
|
||||||
- EmbeddingsResponse
|
- EmbeddingsResponse
|
||||||
|
@ -5435,8 +5416,9 @@ x-tagGroups:
|
||||||
- TrainingConfig
|
- TrainingConfig
|
||||||
- Turn
|
- Turn
|
||||||
- URL
|
- URL
|
||||||
|
- UnregisterMemoryBankRequest
|
||||||
|
- UnregisterModelRequest
|
||||||
- UnstructuredLogEvent
|
- UnstructuredLogEvent
|
||||||
- UpdateModelRequest
|
|
||||||
- UserMessage
|
- UserMessage
|
||||||
- VectorMemoryBank
|
- VectorMemoryBank
|
||||||
- VectorMemoryBankParams
|
- VectorMemoryBankParams
|
||||||
|
|
|
@ -144,3 +144,6 @@ class MemoryBanks(Protocol):
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
provider_memory_bank_id: Optional[str] = None,
|
provider_memory_bank_id: Optional[str] = None,
|
||||||
) -> MemoryBank: ...
|
) -> MemoryBank: ...
|
||||||
|
|
||||||
|
@webmethod(route="/memory_banks/unregister", method="POST")
|
||||||
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -61,28 +61,7 @@ class ModelsClient(Models):
|
||||||
return None
|
return None
|
||||||
return Model(**j)
|
return Model(**j)
|
||||||
|
|
||||||
async def update_model(
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
provider_model_id: Optional[str] = None,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Model:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.put(
|
|
||||||
f"{self.base_url}/models/update",
|
|
||||||
json={
|
|
||||||
"model_id": model_id,
|
|
||||||
"provider_model_id": provider_model_id,
|
|
||||||
"provider_id": provider_id,
|
|
||||||
"metadata": metadata,
|
|
||||||
},
|
|
||||||
headers={"Content-Type": "application/json"},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
return Model(**response.json())
|
|
||||||
|
|
||||||
async def delete_model(self, model_id: str) -> None:
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.delete(
|
response = await client.delete(
|
||||||
f"{self.base_url}/models/delete",
|
f"{self.base_url}/models/delete",
|
||||||
|
|
|
@ -55,14 +55,5 @@ class Models(Protocol):
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> Model: ...
|
) -> Model: ...
|
||||||
|
|
||||||
@webmethod(route="/models/update", method="POST")
|
@webmethod(route="/models/unregister", method="POST")
|
||||||
async def update_model(
|
async def unregister_model(self, model_id: str) -> None: ...
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
provider_model_id: Optional[str] = None,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Model: ...
|
|
||||||
|
|
||||||
@webmethod(route="/models/delete", method="POST")
|
|
||||||
async def delete_model(self, model_id: str) -> None: ...
|
|
||||||
|
|
|
@ -51,6 +51,16 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||||
|
|
||||||
|
|
||||||
|
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
|
api = get_impl_api(p)
|
||||||
|
if api == Api.memory:
|
||||||
|
return await p.unregister_memory_bank(obj.identifier)
|
||||||
|
elif api == Api.inference:
|
||||||
|
return await p.unregister_model(obj.identifier)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unregister not supported for {api}")
|
||||||
|
|
||||||
|
|
||||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||||
|
|
||||||
|
|
||||||
|
@ -148,17 +158,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
async def delete_object(self, obj: RoutableObjectWithProvider) -> None:
|
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||||
# TODO: delete from provider
|
await unregister_object_from_provider(
|
||||||
|
|
||||||
async def update_object(
|
|
||||||
self, obj: RoutableObjectWithProvider
|
|
||||||
) -> RoutableObjectWithProvider:
|
|
||||||
registered_obj = await register_object_with_provider(
|
|
||||||
obj, self.impls_by_provider_id[obj.provider_id]
|
obj, self.impls_by_provider_id[obj.provider_id]
|
||||||
)
|
)
|
||||||
return await self.dist_registry.update(registered_obj)
|
|
||||||
|
|
||||||
async def register_object(
|
async def register_object(
|
||||||
self, obj: RoutableObjectWithProvider
|
self, obj: RoutableObjectWithProvider
|
||||||
|
@ -232,32 +236,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
registered_model = await self.register_object(model)
|
registered_model = await self.register_object(model)
|
||||||
return registered_model
|
return registered_model
|
||||||
|
|
||||||
async def update_model(
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
provider_model_id: Optional[str] = None,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Model:
|
|
||||||
existing_model = await self.get_model(model_id)
|
existing_model = await self.get_model(model_id)
|
||||||
if existing_model is None:
|
if existing_model is None:
|
||||||
raise ValueError(f"Model {model_id} not found")
|
raise ValueError(f"Model {model_id} not found")
|
||||||
|
await self.unregister_object(existing_model)
|
||||||
updated_model = Model(
|
|
||||||
identifier=model_id,
|
|
||||||
provider_resource_id=provider_model_id
|
|
||||||
or existing_model.provider_resource_id,
|
|
||||||
provider_id=provider_id or existing_model.provider_id,
|
|
||||||
metadata=metadata or existing_model.metadata,
|
|
||||||
)
|
|
||||||
registered_model = await self.update_object(updated_model)
|
|
||||||
return registered_model
|
|
||||||
|
|
||||||
async def delete_model(self, model_id: str) -> None:
|
|
||||||
existing_model = await self.get_model(model_id)
|
|
||||||
if existing_model is None:
|
|
||||||
raise ValueError(f"Model {model_id} not found")
|
|
||||||
await self.delete_object(existing_model)
|
|
||||||
|
|
||||||
|
|
||||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
|
@ -333,6 +316,12 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
await self.register_object(memory_bank)
|
await self.register_object(memory_bank)
|
||||||
return memory_bank
|
return memory_bank
|
||||||
|
|
||||||
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||||
|
existing_bank = await self.get_memory_bank(memory_bank_id)
|
||||||
|
if existing_bank is None:
|
||||||
|
raise ValueError(f"Memory bank {memory_bank_id} not found")
|
||||||
|
await self.unregister_object(existing_bank)
|
||||||
|
|
||||||
|
|
||||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
async def list_datasets(self) -> List[Dataset]:
|
async def list_datasets(self) -> List[Dataset]:
|
||||||
|
|
|
@ -45,6 +45,8 @@ class Api(Enum):
|
||||||
class ModelsProtocolPrivate(Protocol):
|
class ModelsProtocolPrivate(Protocol):
|
||||||
async def register_model(self, model: Model) -> None: ...
|
async def register_model(self, model: Model) -> None: ...
|
||||||
|
|
||||||
|
async def unregister_model(self, model_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class ShieldsProtocolPrivate(Protocol):
|
class ShieldsProtocolPrivate(Protocol):
|
||||||
async def register_shield(self, shield: Shield) -> None: ...
|
async def register_shield(self, shield: Shield) -> None: ...
|
||||||
|
@ -55,6 +57,8 @@ class MemoryBanksProtocolPrivate(Protocol):
|
||||||
|
|
||||||
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
|
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
|
||||||
|
|
||||||
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class DatasetsProtocolPrivate(Protocol):
|
class DatasetsProtocolPrivate(Protocol):
|
||||||
async def register_dataset(self, dataset: Dataset) -> None: ...
|
async def register_dataset(self, dataset: Dataset) -> None: ...
|
||||||
|
|
|
@ -71,6 +71,9 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
||||||
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
@ -108,6 +108,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
return VLLMSamplingParams(**kwargs)
|
return VLLMSamplingParams(**kwargs)
|
||||||
|
|
||||||
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
@ -4,6 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
@ -37,10 +39,52 @@ class FaissIndex(EmbeddingIndex):
|
||||||
id_by_index: Dict[int, str]
|
id_by_index: Dict[int, str]
|
||||||
chunk_by_index: Dict[int, str]
|
chunk_by_index: Dict[int, str]
|
||||||
|
|
||||||
def __init__(self, dimension: int):
|
def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
|
||||||
self.index = faiss.IndexFlatL2(dimension)
|
self.index = faiss.IndexFlatL2(dimension)
|
||||||
self.id_by_index = {}
|
self.id_by_index = {}
|
||||||
self.chunk_by_index = {}
|
self.chunk_by_index = {}
|
||||||
|
self.kvstore = kvstore
|
||||||
|
self.bank_id = bank_id
|
||||||
|
self.initialize()
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
if not self.kvstore:
|
||||||
|
return
|
||||||
|
|
||||||
|
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||||
|
stored_data = await self.kvstore.get(index_key)
|
||||||
|
|
||||||
|
if stored_data:
|
||||||
|
data = json.loads(stored_data)
|
||||||
|
self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
|
||||||
|
self.chunk_by_index = {
|
||||||
|
int(k): Chunk.model_validate_json(v)
|
||||||
|
for k, v in data["chunk_by_index"].items()
|
||||||
|
}
|
||||||
|
|
||||||
|
index_bytes = base64.b64decode(data["faiss_index"])
|
||||||
|
self.index = faiss.deserialize_index(index_bytes)
|
||||||
|
|
||||||
|
async def _save_index(self):
|
||||||
|
if not self.kvstore or not self.bank_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
index_bytes = faiss.serialize_index(self.index)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"id_by_index": self.id_by_index,
|
||||||
|
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
|
||||||
|
"faiss_index": base64.b64encode(index_bytes).decode(),
|
||||||
|
}
|
||||||
|
|
||||||
|
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||||
|
await self.kvstore.set(key=index_key, value=json.dumps(data))
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
if not self.kvstore or not self.bank_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
|
||||||
|
|
||||||
@tracing.span(name="add_chunks")
|
@tracing.span(name="add_chunks")
|
||||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||||
|
@ -51,6 +95,9 @@ class FaissIndex(EmbeddingIndex):
|
||||||
|
|
||||||
self.index.add(np.array(embeddings).astype(np.float32))
|
self.index.add(np.array(embeddings).astype(np.float32))
|
||||||
|
|
||||||
|
# Save updated index
|
||||||
|
await self._save_index()
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self, embedding: NDArray, k: int, score_threshold: float
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
) -> QueryDocumentsResponse:
|
) -> QueryDocumentsResponse:
|
||||||
|
@ -85,7 +132,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||||
for bank_data in stored_banks:
|
for bank_data in stored_banks:
|
||||||
bank = VectorMemoryBank.model_validate_json(bank_data)
|
bank = VectorMemoryBank.model_validate_json(bank_data)
|
||||||
index = BankWithIndex(
|
index = BankWithIndex(
|
||||||
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
|
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore)
|
||||||
)
|
)
|
||||||
self.cache[bank.identifier] = index
|
self.cache[bank.identifier] = index
|
||||||
|
|
||||||
|
@ -110,13 +157,19 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||||
|
|
||||||
# Store in cache
|
# Store in cache
|
||||||
index = BankWithIndex(
|
index = BankWithIndex(
|
||||||
bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
|
bank=memory_bank,
|
||||||
|
index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
|
||||||
)
|
)
|
||||||
self.cache[memory_bank.identifier] = index
|
self.cache[memory_bank.identifier] = index
|
||||||
|
|
||||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||||
return [i.bank for i in self.cache.values()]
|
return [i.bank for i in self.cache.values()]
|
||||||
|
|
||||||
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||||
|
await self.cache[memory_bank_id].index.delete()
|
||||||
|
del self.cache[memory_bank_id]
|
||||||
|
await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
|
|
@ -93,6 +93,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
@ -69,6 +69,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
|
|
@ -58,6 +58,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
@ -67,6 +67,9 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
|
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
await self.client.delete_collection(self.collection.name)
|
||||||
|
|
||||||
|
|
||||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
def __init__(self, url: str) -> None:
|
def __init__(self, url: str) -> None:
|
||||||
|
@ -134,6 +137,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
|
|
||||||
return [i.bank for i in self.cache.values()]
|
return [i.bank for i in self.cache.values()]
|
||||||
|
|
||||||
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||||
|
await self.cache[memory_bank_id].index.delete()
|
||||||
|
del self.cache[memory_bank_id]
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
|
|
@ -112,6 +112,9 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
|
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||||
|
|
||||||
|
|
||||||
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
def __init__(self, config: PGVectorConfig) -> None:
|
def __init__(self, config: PGVectorConfig) -> None:
|
||||||
|
@ -177,6 +180,10 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
)
|
)
|
||||||
self.cache[memory_bank.identifier] = index
|
self.cache[memory_bank.identifier] = index
|
||||||
|
|
||||||
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||||
|
await self.cache[memory_bank_id].index.delete()
|
||||||
|
del self.cache[memory_bank_id]
|
||||||
|
|
||||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||||
banks = load_models(self.cursor, VectorMemoryBank)
|
banks = load_models(self.cursor, VectorMemoryBank)
|
||||||
for bank in banks:
|
for bank in banks:
|
||||||
|
|
|
@ -54,4 +54,4 @@ class TestModelRegistration:
|
||||||
assert updated_model.provider_resource_id != old_model.provider_resource_id
|
assert updated_model.provider_resource_id != old_model.provider_resource_id
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
await models_impl.delete_model(model_id=model_id)
|
await models_impl.unregister_model(model_id=model_id)
|
||||||
|
|
|
@ -4,6 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
@ -43,9 +45,10 @@ def sample_documents():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def register_memory_bank(banks_impl: MemoryBanks):
|
async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
|
||||||
|
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
||||||
return await banks_impl.register_memory_bank(
|
return await banks_impl.register_memory_bank(
|
||||||
memory_bank_id="test_bank",
|
memory_bank_id=bank_id,
|
||||||
params=VectorMemoryBankParams(
|
params=VectorMemoryBankParams(
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
|
@ -57,43 +60,70 @@ async def register_memory_bank(banks_impl: MemoryBanks):
|
||||||
class TestMemory:
|
class TestMemory:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_banks_list(self, memory_stack):
|
async def test_banks_list(self, memory_stack):
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
|
||||||
_, banks_impl = memory_stack
|
_, banks_impl = memory_stack
|
||||||
|
|
||||||
|
# Register a test bank
|
||||||
|
registered_bank = await register_memory_bank(banks_impl)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify our bank shows up in list
|
||||||
response = await banks_impl.list_memory_banks()
|
response = await banks_impl.list_memory_banks()
|
||||||
assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
assert len(response) == 0
|
assert any(
|
||||||
|
bank.memory_bank_id == registered_bank.memory_bank_id
|
||||||
|
for bank in response
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# Clean up
|
||||||
|
await banks_impl.unregister_memory_bank(registered_bank.memory_bank_id)
|
||||||
|
|
||||||
|
# Verify our bank was removed
|
||||||
|
response = await banks_impl.list_memory_banks()
|
||||||
|
assert all(
|
||||||
|
bank.memory_bank_id != registered_bank.memory_bank_id for bank in response
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_banks_register(self, memory_stack):
|
async def test_banks_register(self, memory_stack):
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
|
||||||
_, banks_impl = memory_stack
|
_, banks_impl = memory_stack
|
||||||
|
|
||||||
await banks_impl.register_memory_bank(
|
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
||||||
memory_bank_id="test_bank_no_provider",
|
|
||||||
params=VectorMemoryBankParams(
|
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
|
||||||
chunk_size_in_tokens=512,
|
|
||||||
overlap_size_in_tokens=64,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
response = await banks_impl.list_memory_banks()
|
|
||||||
assert isinstance(response, list)
|
|
||||||
assert len(response) == 1
|
|
||||||
|
|
||||||
# register same memory bank with same id again will fail
|
try:
|
||||||
|
# Register initial bank
|
||||||
await banks_impl.register_memory_bank(
|
await banks_impl.register_memory_bank(
|
||||||
memory_bank_id="test_bank_no_provider",
|
memory_bank_id=bank_id,
|
||||||
params=VectorMemoryBankParams(
|
params=VectorMemoryBankParams(
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
overlap_size_in_tokens=64,
|
overlap_size_in_tokens=64,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Verify our bank exists
|
||||||
response = await banks_impl.list_memory_banks()
|
response = await banks_impl.list_memory_banks()
|
||||||
assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
assert len(response) == 1
|
assert any(bank.memory_bank_id == bank_id for bank in response)
|
||||||
|
|
||||||
|
# Try registering same bank again
|
||||||
|
await banks_impl.register_memory_bank(
|
||||||
|
memory_bank_id=bank_id,
|
||||||
|
params=VectorMemoryBankParams(
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
|
overlap_size_in_tokens=64,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify still only one instance of our bank
|
||||||
|
response = await banks_impl.list_memory_banks()
|
||||||
|
assert isinstance(response, list)
|
||||||
|
assert (
|
||||||
|
len([bank for bank in response if bank.memory_bank_id == bank_id]) == 1
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# Clean up
|
||||||
|
await banks_impl.unregister_memory_bank(bank_id)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_documents(self, memory_stack, sample_documents):
|
async def test_query_documents(self, memory_stack, sample_documents):
|
||||||
|
@ -102,17 +132,23 @@ class TestMemory:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||||
|
|
||||||
await register_memory_bank(banks_impl)
|
registered_bank = await register_memory_bank(banks_impl)
|
||||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
await memory_impl.insert_documents(
|
||||||
|
registered_bank.memory_bank_id, sample_documents
|
||||||
|
)
|
||||||
|
|
||||||
query1 = "programming language"
|
query1 = "programming language"
|
||||||
response1 = await memory_impl.query_documents("test_bank", query1)
|
response1 = await memory_impl.query_documents(
|
||||||
|
registered_bank.memory_bank_id, query1
|
||||||
|
)
|
||||||
assert_valid_response(response1)
|
assert_valid_response(response1)
|
||||||
assert any("Python" in chunk.content for chunk in response1.chunks)
|
assert any("Python" in chunk.content for chunk in response1.chunks)
|
||||||
|
|
||||||
# Test case 3: Query with semantic similarity
|
# Test case 3: Query with semantic similarity
|
||||||
query3 = "AI and brain-inspired computing"
|
query3 = "AI and brain-inspired computing"
|
||||||
response3 = await memory_impl.query_documents("test_bank", query3)
|
response3 = await memory_impl.query_documents(
|
||||||
|
registered_bank.memory_bank_id, query3
|
||||||
|
)
|
||||||
assert_valid_response(response3)
|
assert_valid_response(response3)
|
||||||
assert any(
|
assert any(
|
||||||
"neural networks" in chunk.content.lower() for chunk in response3.chunks
|
"neural networks" in chunk.content.lower() for chunk in response3.chunks
|
||||||
|
@ -121,14 +157,18 @@ class TestMemory:
|
||||||
# Test case 4: Query with limit on number of results
|
# Test case 4: Query with limit on number of results
|
||||||
query4 = "computer"
|
query4 = "computer"
|
||||||
params4 = {"max_chunks": 2}
|
params4 = {"max_chunks": 2}
|
||||||
response4 = await memory_impl.query_documents("test_bank", query4, params4)
|
response4 = await memory_impl.query_documents(
|
||||||
|
registered_bank.memory_bank_id, query4, params4
|
||||||
|
)
|
||||||
assert_valid_response(response4)
|
assert_valid_response(response4)
|
||||||
assert len(response4.chunks) <= 2
|
assert len(response4.chunks) <= 2
|
||||||
|
|
||||||
# Test case 5: Query with threshold on similarity score
|
# Test case 5: Query with threshold on similarity score
|
||||||
query5 = "quantum computing" # Not directly related to any document
|
query5 = "quantum computing" # Not directly related to any document
|
||||||
params5 = {"score_threshold": 0.2}
|
params5 = {"score_threshold": 0.2}
|
||||||
response5 = await memory_impl.query_documents("test_bank", query5, params5)
|
response5 = await memory_impl.query_documents(
|
||||||
|
registered_bank.memory_bank_id, query5, params5
|
||||||
|
)
|
||||||
assert_valid_response(response5)
|
assert_valid_response(response5)
|
||||||
print("The scores are:", response5.scores)
|
print("The scores are:", response5.scores)
|
||||||
assert all(score >= 0.2 for score in response5.scores)
|
assert all(score >= 0.2 for score in response5.scores)
|
||||||
|
|
|
@ -145,6 +145,10 @@ class EmbeddingIndex(ABC):
|
||||||
) -> QueryDocumentsResponse:
|
) -> QueryDocumentsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BankWithIndex:
|
class BankWithIndex:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue