mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
nuke updates
This commit is contained in:
parent
690e525a36
commit
aa93eeb2b7
15 changed files with 15 additions and 429 deletions
|
@ -21,7 +21,7 @@
|
||||||
"info": {
|
"info": {
|
||||||
"title": "[DRAFT] Llama Stack Specification",
|
"title": "[DRAFT] Llama Stack Specification",
|
||||||
"version": "0.0.1",
|
"version": "0.0.1",
|
||||||
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 16:18:00.903125"
|
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559"
|
||||||
},
|
},
|
||||||
"servers": [
|
"servers": [
|
||||||
{
|
{
|
||||||
|
@ -2291,75 +2291,6 @@
|
||||||
"required": true
|
"required": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"/memory_banks/update": {
|
|
||||||
"post": {
|
|
||||||
"responses": {},
|
|
||||||
"tags": [
|
|
||||||
"MemoryBanks"
|
|
||||||
],
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "X-LlamaStack-ProviderData",
|
|
||||||
"in": "header",
|
|
||||||
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
|
||||||
"required": false,
|
|
||||||
"schema": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"requestBody": {
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/UpdateMemoryBankRequest"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/models/update": {
|
|
||||||
"post": {
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "OK",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/Model"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"tags": [
|
|
||||||
"Models"
|
|
||||||
],
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "X-LlamaStack-ProviderData",
|
|
||||||
"in": "header",
|
|
||||||
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
|
||||||
"required": false,
|
|
||||||
"schema": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"requestBody": {
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/UpdateModelRequest"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
|
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
|
||||||
|
@ -7985,84 +7916,6 @@
|
||||||
"required": [
|
"required": [
|
||||||
"model_id"
|
"model_id"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
"UpdateMemoryBankRequest": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"memory_bank_id": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"params": {
|
|
||||||
"oneOf": [
|
|
||||||
{
|
|
||||||
"$ref": "#/components/schemas/VectorMemoryBankParams"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"$ref": "#/components/schemas/KeyValueMemoryBankParams"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"$ref": "#/components/schemas/KeywordMemoryBankParams"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"$ref": "#/components/schemas/GraphMemoryBankParams"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"provider_id": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"provider_memory_bank_id": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": [
|
|
||||||
"memory_bank_id",
|
|
||||||
"params"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"UpdateModelRequest": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"model_id": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"provider_model_id": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"provider_id": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"metadata": {
|
|
||||||
"type": "object",
|
|
||||||
"additionalProperties": {
|
|
||||||
"oneOf": [
|
|
||||||
{
|
|
||||||
"type": "null"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "boolean"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "number"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "array"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "object"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": [
|
|
||||||
"model_id"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"responses": {}
|
"responses": {}
|
||||||
|
@ -8679,14 +8532,6 @@
|
||||||
"name": "UnstructuredLogEvent",
|
"name": "UnstructuredLogEvent",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "UpdateMemoryBankRequest",
|
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UpdateMemoryBankRequest\" />"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "UpdateModelRequest",
|
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UpdateModelRequest\" />"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "UserMessage",
|
"name": "UserMessage",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserMessage\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserMessage\" />"
|
||||||
|
@ -8873,8 +8718,6 @@
|
||||||
"UnregisterMemoryBankRequest",
|
"UnregisterMemoryBankRequest",
|
||||||
"UnregisterModelRequest",
|
"UnregisterModelRequest",
|
||||||
"UnstructuredLogEvent",
|
"UnstructuredLogEvent",
|
||||||
"UpdateMemoryBankRequest",
|
|
||||||
"UpdateModelRequest",
|
|
||||||
"UserMessage",
|
"UserMessage",
|
||||||
"VectorMemoryBank",
|
"VectorMemoryBank",
|
||||||
"VectorMemoryBankParams",
|
"VectorMemoryBankParams",
|
||||||
|
|
|
@ -3288,47 +3288,6 @@ components:
|
||||||
- message
|
- message
|
||||||
- severity
|
- severity
|
||||||
type: object
|
type: object
|
||||||
UpdateMemoryBankRequest:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
memory_bank_id:
|
|
||||||
type: string
|
|
||||||
params:
|
|
||||||
oneOf:
|
|
||||||
- $ref: '#/components/schemas/VectorMemoryBankParams'
|
|
||||||
- $ref: '#/components/schemas/KeyValueMemoryBankParams'
|
|
||||||
- $ref: '#/components/schemas/KeywordMemoryBankParams'
|
|
||||||
- $ref: '#/components/schemas/GraphMemoryBankParams'
|
|
||||||
provider_id:
|
|
||||||
type: string
|
|
||||||
provider_memory_bank_id:
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- memory_bank_id
|
|
||||||
- params
|
|
||||||
type: object
|
|
||||||
UpdateModelRequest:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
metadata:
|
|
||||||
additionalProperties:
|
|
||||||
oneOf:
|
|
||||||
- type: 'null'
|
|
||||||
- type: boolean
|
|
||||||
- type: number
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
- type: object
|
|
||||||
type: object
|
|
||||||
model_id:
|
|
||||||
type: string
|
|
||||||
provider_id:
|
|
||||||
type: string
|
|
||||||
provider_model_id:
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- model_id
|
|
||||||
type: object
|
|
||||||
UserMessage:
|
UserMessage:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -3441,7 +3400,7 @@ info:
|
||||||
description: "This is the specification of the llama stack that provides\n \
|
description: "This is the specification of the llama stack that provides\n \
|
||||||
\ a set of endpoints and their corresponding interfaces that are tailored\
|
\ a set of endpoints and their corresponding interfaces that are tailored\
|
||||||
\ to\n best leverage Llama Models. The specification is still in\
|
\ to\n best leverage Llama Models. The specification is still in\
|
||||||
\ draft and subject to change.\n Generated at 2024-11-14 16:18:00.903125"
|
\ draft and subject to change.\n Generated at 2024-11-14 17:04:24.301559"
|
||||||
title: '[DRAFT] Llama Stack Specification'
|
title: '[DRAFT] Llama Stack Specification'
|
||||||
version: 0.0.1
|
version: 0.0.1
|
||||||
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
||||||
|
@ -4264,25 +4223,6 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- MemoryBanks
|
- MemoryBanks
|
||||||
/memory_banks/update:
|
|
||||||
post:
|
|
||||||
parameters:
|
|
||||||
- description: JSON-encoded provider data which will be made available to the
|
|
||||||
adapter servicing the API
|
|
||||||
in: header
|
|
||||||
name: X-LlamaStack-ProviderData
|
|
||||||
required: false
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
requestBody:
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/UpdateMemoryBankRequest'
|
|
||||||
required: true
|
|
||||||
responses: {}
|
|
||||||
tags:
|
|
||||||
- MemoryBanks
|
|
||||||
/models/get:
|
/models/get:
|
||||||
get:
|
get:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -4374,31 +4314,6 @@ paths:
|
||||||
description: OK
|
description: OK
|
||||||
tags:
|
tags:
|
||||||
- Models
|
- Models
|
||||||
/models/update:
|
|
||||||
post:
|
|
||||||
parameters:
|
|
||||||
- description: JSON-encoded provider data which will be made available to the
|
|
||||||
adapter servicing the API
|
|
||||||
in: header
|
|
||||||
name: X-LlamaStack-ProviderData
|
|
||||||
required: false
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
requestBody:
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/UpdateModelRequest'
|
|
||||||
required: true
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Model'
|
|
||||||
description: OK
|
|
||||||
tags:
|
|
||||||
- Models
|
|
||||||
/post_training/job/artifacts:
|
/post_training/job/artifacts:
|
||||||
get:
|
get:
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -5330,12 +5245,6 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
|
||||||
/>
|
/>
|
||||||
name: UnstructuredLogEvent
|
name: UnstructuredLogEvent
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UpdateMemoryBankRequest"
|
|
||||||
/>
|
|
||||||
name: UpdateMemoryBankRequest
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UpdateModelRequest"
|
|
||||||
/>
|
|
||||||
name: UpdateModelRequest
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UserMessage" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UserMessage" />
|
||||||
name: UserMessage
|
name: UserMessage
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/VectorMemoryBank"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/VectorMemoryBank"
|
||||||
|
@ -5510,8 +5419,6 @@ x-tagGroups:
|
||||||
- UnregisterMemoryBankRequest
|
- UnregisterMemoryBankRequest
|
||||||
- UnregisterModelRequest
|
- UnregisterModelRequest
|
||||||
- UnstructuredLogEvent
|
- UnstructuredLogEvent
|
||||||
- UpdateMemoryBankRequest
|
|
||||||
- UpdateModelRequest
|
|
||||||
- UserMessage
|
- UserMessage
|
||||||
- VectorMemoryBank
|
- VectorMemoryBank
|
||||||
- VectorMemoryBankParams
|
- VectorMemoryBankParams
|
||||||
|
|
|
@ -145,14 +145,5 @@ class MemoryBanks(Protocol):
|
||||||
provider_memory_bank_id: Optional[str] = None,
|
provider_memory_bank_id: Optional[str] = None,
|
||||||
) -> MemoryBank: ...
|
) -> MemoryBank: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/update", method="POST")
|
|
||||||
async def update_memory_bank(
|
|
||||||
self,
|
|
||||||
memory_bank_id: str,
|
|
||||||
params: BankParams,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
provider_memory_bank_id: Optional[str] = None,
|
|
||||||
) -> MemoryBank: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/unregister", method="POST")
|
@webmethod(route="/memory_banks/unregister", method="POST")
|
||||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -61,27 +61,6 @@ class ModelsClient(Models):
|
||||||
return None
|
return None
|
||||||
return Model(**j)
|
return Model(**j)
|
||||||
|
|
||||||
async def update_model(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
provider_model_id: Optional[str] = None,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Model:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.put(
|
|
||||||
f"{self.base_url}/models/update",
|
|
||||||
json={
|
|
||||||
"model_id": model_id,
|
|
||||||
"provider_model_id": provider_model_id,
|
|
||||||
"provider_id": provider_id,
|
|
||||||
"metadata": metadata,
|
|
||||||
},
|
|
||||||
headers={"Content-Type": "application/json"},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
return Model(**response.json())
|
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.delete(
|
response = await client.delete(
|
||||||
|
|
|
@ -55,14 +55,5 @@ class Models(Protocol):
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> Model: ...
|
) -> Model: ...
|
||||||
|
|
||||||
@webmethod(route="/models/update", method="POST")
|
|
||||||
async def update_model(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
provider_model_id: Optional[str] = None,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Model: ...
|
|
||||||
|
|
||||||
@webmethod(route="/models/unregister", method="POST")
|
@webmethod(route="/models/unregister", method="POST")
|
||||||
async def unregister_model(self, model_id: str) -> None: ...
|
async def unregister_model(self, model_id: str) -> None: ...
|
||||||
|
|
|
@ -51,18 +51,6 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||||
|
|
||||||
|
|
||||||
async def update_object_with_provider(
|
|
||||||
obj: RoutableObject, p: Any
|
|
||||||
) -> Optional[RoutableObject]:
|
|
||||||
api = get_impl_api(p)
|
|
||||||
if api == Api.memory:
|
|
||||||
return await p.update_memory_bank(obj)
|
|
||||||
elif api == Api.inference:
|
|
||||||
return await p.update_model(obj)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Update not supported for {api}")
|
|
||||||
|
|
||||||
|
|
||||||
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
api = get_impl_api(p)
|
api = get_impl_api(p)
|
||||||
if api == Api.memory:
|
if api == Api.memory:
|
||||||
|
@ -176,14 +164,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
obj, self.impls_by_provider_id[obj.provider_id]
|
obj, self.impls_by_provider_id[obj.provider_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def update_object(
|
|
||||||
self, obj: RoutableObjectWithProvider
|
|
||||||
) -> RoutableObjectWithProvider:
|
|
||||||
registered_obj = await update_object_with_provider(
|
|
||||||
obj, self.impls_by_provider_id[obj.provider_id]
|
|
||||||
)
|
|
||||||
return await self.dist_registry.update(registered_obj or obj)
|
|
||||||
|
|
||||||
async def register_object(
|
async def register_object(
|
||||||
self, obj: RoutableObjectWithProvider
|
self, obj: RoutableObjectWithProvider
|
||||||
) -> RoutableObjectWithProvider:
|
) -> RoutableObjectWithProvider:
|
||||||
|
@ -256,27 +236,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
registered_model = await self.register_object(model)
|
registered_model = await self.register_object(model)
|
||||||
return registered_model
|
return registered_model
|
||||||
|
|
||||||
async def update_model(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
provider_model_id: Optional[str] = None,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Model:
|
|
||||||
existing_model = await self.get_model(model_id)
|
|
||||||
if existing_model is None:
|
|
||||||
raise ValueError(f"Model {model_id} not found")
|
|
||||||
|
|
||||||
updated_model = Model(
|
|
||||||
identifier=model_id,
|
|
||||||
provider_resource_id=provider_model_id
|
|
||||||
or existing_model.provider_resource_id,
|
|
||||||
provider_id=provider_id or existing_model.provider_id,
|
|
||||||
metadata=metadata or existing_model.metadata,
|
|
||||||
)
|
|
||||||
registered_model = await self.update_object(updated_model)
|
|
||||||
return registered_model
|
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
existing_model = await self.get_model(model_id)
|
existing_model = await self.get_model(model_id)
|
||||||
if existing_model is None:
|
if existing_model is None:
|
||||||
|
@ -357,31 +316,6 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
await self.register_object(memory_bank)
|
await self.register_object(memory_bank)
|
||||||
return memory_bank
|
return memory_bank
|
||||||
|
|
||||||
async def update_memory_bank(
|
|
||||||
self,
|
|
||||||
memory_bank_id: str,
|
|
||||||
params: BankParams,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
provider_memory_bank_id: Optional[str] = None,
|
|
||||||
) -> MemoryBank:
|
|
||||||
existing_bank = await self.get_memory_bank(memory_bank_id)
|
|
||||||
if existing_bank is None:
|
|
||||||
raise ValueError(f"Memory bank {memory_bank_id} not found")
|
|
||||||
|
|
||||||
updated_bank = parse_obj_as(
|
|
||||||
MemoryBank,
|
|
||||||
{
|
|
||||||
"identifier": memory_bank_id,
|
|
||||||
"type": ResourceType.memory_bank.value,
|
|
||||||
"provider_id": provider_id or existing_bank.provider_id,
|
|
||||||
"provider_resource_id": provider_memory_bank_id
|
|
||||||
or existing_bank.provider_resource_id,
|
|
||||||
**params.model_dump(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
registered_bank = await self.update_object(updated_bank)
|
|
||||||
return registered_bank
|
|
||||||
|
|
||||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||||
existing_bank = await self.get_memory_bank(memory_bank_id)
|
existing_bank = await self.get_memory_bank(memory_bank_id)
|
||||||
if existing_bank is None:
|
if existing_bank is None:
|
||||||
|
|
|
@ -45,8 +45,6 @@ class Api(Enum):
|
||||||
class ModelsProtocolPrivate(Protocol):
|
class ModelsProtocolPrivate(Protocol):
|
||||||
async def register_model(self, model: Model) -> None: ...
|
async def register_model(self, model: Model) -> None: ...
|
||||||
|
|
||||||
async def update_model(self, model: Model) -> None: ...
|
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None: ...
|
async def unregister_model(self, model_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,8 +59,6 @@ class MemoryBanksProtocolPrivate(Protocol):
|
||||||
|
|
||||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
||||||
|
|
||||||
async def update_memory_bank(self, memory_bank: MemoryBank) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetsProtocolPrivate(Protocol):
|
class DatasetsProtocolPrivate(Protocol):
|
||||||
async def register_dataset(self, dataset: Dataset) -> None: ...
|
async def register_dataset(self, dataset: Dataset) -> None: ...
|
||||||
|
@ -107,6 +103,7 @@ class RoutingTable(Protocol):
|
||||||
def get_provider_impl(self, routing_key: str) -> Any: ...
|
def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: this can now be inlined into RemoteProviderSpec
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AdapterSpec(BaseModel):
|
class AdapterSpec(BaseModel):
|
||||||
adapter_type: str = Field(
|
adapter_type: str = Field(
|
||||||
|
@ -179,12 +176,10 @@ class RemoteProviderConfig(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RemoteProviderSpec(ProviderSpec):
|
class RemoteProviderSpec(ProviderSpec):
|
||||||
adapter: Optional[AdapterSpec] = Field(
|
adapter: AdapterSpec = Field(
|
||||||
default=None,
|
|
||||||
description="""
|
description="""
|
||||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
API responses, specify the adapter here.
|
||||||
as being "Llama Stack compatible"
|
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -194,38 +189,21 @@ as being "Llama Stack compatible"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def module(self) -> str:
|
def module(self) -> str:
|
||||||
if self.adapter:
|
return self.adapter.module
|
||||||
return self.adapter.module
|
|
||||||
return "llama_stack.distribution.client"
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pip_packages(self) -> List[str]:
|
def pip_packages(self) -> List[str]:
|
||||||
if self.adapter:
|
return self.adapter.pip_packages
|
||||||
return self.adapter.pip_packages
|
|
||||||
return []
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_data_validator(self) -> Optional[str]:
|
def provider_data_validator(self) -> Optional[str]:
|
||||||
if self.adapter:
|
return self.adapter.provider_data_validator
|
||||||
return self.adapter.provider_data_validator
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
|
||||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
|
||||||
|
|
||||||
|
|
||||||
# Can avoid this by using Pydantic computed_field
|
|
||||||
def remote_provider_spec(
|
|
||||||
api: Api, adapter: Optional[AdapterSpec] = None
|
|
||||||
) -> RemoteProviderSpec:
|
|
||||||
config_class = (
|
|
||||||
adapter.config_class
|
|
||||||
if adapter and adapter.config_class
|
|
||||||
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
|
|
||||||
)
|
|
||||||
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
|
|
||||||
|
|
||||||
return RemoteProviderSpec(
|
return RemoteProviderSpec(
|
||||||
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
|
api=api,
|
||||||
|
provider_type=f"remote::{adapter.adapter_type}",
|
||||||
|
config_class=adapter.config_class,
|
||||||
|
adapter=adapter,
|
||||||
)
|
)
|
||||||
|
|
|
@ -71,9 +71,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
||||||
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def update_model(self, model: Model) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -108,9 +108,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
return VLLMSamplingParams(**kwargs)
|
return VLLMSamplingParams(**kwargs)
|
||||||
|
|
||||||
async def update_model(self, model: Model) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -48,10 +48,9 @@ class FaissIndex(EmbeddingIndex):
|
||||||
self.initialize()
|
self.initialize()
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
if not self.kvstore or not self.bank_id:
|
if not self.kvstore:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Load existing index data from kvstore
|
|
||||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||||
stored_data = await self.kvstore.get(index_key)
|
stored_data = await self.kvstore.get(index_key)
|
||||||
|
|
||||||
|
@ -63,7 +62,6 @@ class FaissIndex(EmbeddingIndex):
|
||||||
for k, v in data["chunk_by_index"].items()
|
for k, v in data["chunk_by_index"].items()
|
||||||
}
|
}
|
||||||
|
|
||||||
# Load FAISS index
|
|
||||||
index_bytes = base64.b64decode(data["faiss_index"])
|
index_bytes = base64.b64decode(data["faiss_index"])
|
||||||
self.index = faiss.deserialize_index(index_bytes)
|
self.index = faiss.deserialize_index(index_bytes)
|
||||||
|
|
||||||
|
@ -71,17 +69,14 @@ class FaissIndex(EmbeddingIndex):
|
||||||
if not self.kvstore or not self.bank_id:
|
if not self.kvstore or not self.bank_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Serialize FAISS index
|
|
||||||
index_bytes = faiss.serialize_index(self.index)
|
index_bytes = faiss.serialize_index(self.index)
|
||||||
|
|
||||||
# Prepare data for storage
|
|
||||||
data = {
|
data = {
|
||||||
"id_by_index": self.id_by_index,
|
"id_by_index": self.id_by_index,
|
||||||
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
|
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
|
||||||
"faiss_index": base64.b64encode(index_bytes).decode(),
|
"faiss_index": base64.b64encode(index_bytes).decode(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store in kvstore
|
|
||||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||||
await self.kvstore.set(key=index_key, value=json.dumps(data))
|
await self.kvstore.set(key=index_key, value=json.dumps(data))
|
||||||
|
|
||||||
|
@ -175,15 +170,6 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||||
del self.cache[memory_bank_id]
|
del self.cache[memory_bank_id]
|
||||||
await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
|
await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
|
||||||
|
|
||||||
async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
|
|
||||||
# Not possible to update the index in place, so we delete and recreate
|
|
||||||
await self.cache[memory_bank.identifier].index.delete()
|
|
||||||
|
|
||||||
self.cache[memory_bank.identifier] = BankWithIndex(
|
|
||||||
bank=memory_bank,
|
|
||||||
index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
|
|
@ -93,9 +93,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def update_model(self, model: Model) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -69,9 +69,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def update_model(self, model: Model) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -58,9 +58,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def update_model(self, model: Model) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -141,10 +141,6 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
await self.cache[memory_bank_id].index.delete()
|
await self.cache[memory_bank_id].index.delete()
|
||||||
del self.cache[memory_bank_id]
|
del self.cache[memory_bank_id]
|
||||||
|
|
||||||
async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
|
|
||||||
await self.unregister_memory_bank(memory_bank.identifier)
|
|
||||||
await self.register_memory_bank(memory_bank)
|
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
|
|
@ -184,10 +184,6 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
await self.cache[memory_bank_id].index.delete()
|
await self.cache[memory_bank_id].index.delete()
|
||||||
del self.cache[memory_bank_id]
|
del self.cache[memory_bank_id]
|
||||||
|
|
||||||
async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
|
|
||||||
await self.unregister_memory_bank(memory_bank.identifier)
|
|
||||||
await self.register_memory_bank(memory_bank)
|
|
||||||
|
|
||||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||||
banks = load_models(self.cursor, VectorMemoryBank)
|
banks = load_models(self.cursor, VectorMemoryBank)
|
||||||
for bank in banks:
|
for bank in banks:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue