mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 05:52:36 +00:00

nuke updates

commit aa93eeb2b7 (parent 690e525a36)
15 changed files with 15 additions and 429 deletions

@@ -145,14 +145,5 @@ class MemoryBanks(Protocol):
         provider_memory_bank_id: Optional[str] = None,
     ) -> MemoryBank: ...

-    @webmethod(route="/memory_banks/update", method="POST")
-    async def update_memory_bank(
-        self,
-        memory_bank_id: str,
-        params: BankParams,
-        provider_id: Optional[str] = None,
-        provider_memory_bank_id: Optional[str] = None,
-    ) -> MemoryBank: ...
-
     @webmethod(route="/memory_banks/unregister", method="POST")
     async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...

@@ -7,7 +7,7 @@
 import asyncio
 import json

-from typing import Any, Dict, List, Optional
+from typing import List, Optional

 import fire
 import httpx

@@ -61,27 +61,6 @@ class ModelsClient(Models):
                 return None
             return Model(**j)

-    async def update_model(
-        self,
-        model_id: str,
-        provider_model_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ) -> Model:
-        async with httpx.AsyncClient() as client:
-            response = await client.put(
-                f"{self.base_url}/models/update",
-                json={
-                    "model_id": model_id,
-                    "provider_model_id": provider_model_id,
-                    "provider_id": provider_id,
-                    "metadata": metadata,
-                },
-                headers={"Content-Type": "application/json"},
-            )
-            response.raise_for_status()
-            return Model(**response.json())
-
     async def unregister_model(self, model_id: str) -> None:
         async with httpx.AsyncClient() as client:
             response = await client.delete(
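
With /models/update and its only client-side caller gone, a client that needs to repoint a registered model has to unregister it and register it again. A minimal sketch of that flow follows; the /models/register payload and the unregister URL are assumptions (the surviving unregister call is truncated in this hunk), mirroring the fields the removed update_model carried.

# Hedged sketch only; routes and payload shapes are assumed, not taken from this diff.
import httpx


async def replace_model(base_url: str, model_id: str, provider_id: str) -> None:
    async with httpx.AsyncClient() as client:
        # Drop the existing registration (mirrors ModelsClient.unregister_model).
        response = await client.delete(
            f"{base_url}/models/unregister",
            params={"model_id": model_id},
        )
        response.raise_for_status()

        # Re-register under the new provider mapping.
        response = await client.post(
            f"{base_url}/models/register",
            json={"model_id": model_id, "provider_id": provider_id},
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
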
@@ -55,14 +55,5 @@ class Models(Protocol):
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Model: ...

-    @webmethod(route="/models/update", method="POST")
-    async def update_model(
-        self,
-        model_id: str,
-        provider_model_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ) -> Model: ...
-
     @webmethod(route="/models/unregister", method="POST")
     async def unregister_model(self, model_id: str) -> None: ...

@@ -51,18 +51,6 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
         raise ValueError(f"Unknown API {api} for registering object with provider")


-async def update_object_with_provider(
-    obj: RoutableObject, p: Any
-) -> Optional[RoutableObject]:
-    api = get_impl_api(p)
-    if api == Api.memory:
-        return await p.update_memory_bank(obj)
-    elif api == Api.inference:
-        return await p.update_model(obj)
-    else:
-        raise ValueError(f"Update not supported for {api}")
-
-
 async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)
     if api == Api.memory:

@@ -176,14 +164,6 @@ class CommonRoutingTableImpl(RoutingTable):
             obj, self.impls_by_provider_id[obj.provider_id]
         )

-    async def update_object(
-        self, obj: RoutableObjectWithProvider
-    ) -> RoutableObjectWithProvider:
-        registered_obj = await update_object_with_provider(
-            obj, self.impls_by_provider_id[obj.provider_id]
-        )
-        return await self.dist_registry.update(registered_obj or obj)
-
     async def register_object(
         self, obj: RoutableObjectWithProvider
     ) -> RoutableObjectWithProvider:

@@ -256,27 +236,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         registered_model = await self.register_object(model)
         return registered_model

-    async def update_model(
-        self,
-        model_id: str,
-        provider_model_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ) -> Model:
-        existing_model = await self.get_model(model_id)
-        if existing_model is None:
-            raise ValueError(f"Model {model_id} not found")
-
-        updated_model = Model(
-            identifier=model_id,
-            provider_resource_id=provider_model_id
-            or existing_model.provider_resource_id,
-            provider_id=provider_id or existing_model.provider_id,
-            metadata=metadata or existing_model.metadata,
-        )
-        registered_model = await self.update_object(updated_model)
-        return registered_model
-
     async def unregister_model(self, model_id: str) -> None:
         existing_model = await self.get_model(model_id)
         if existing_model is None:
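
Inside the stack, the removed ModelsRoutingTable.update_model merged the caller's overrides into the existing Model and pushed the result through update_object. After this commit the closest equivalent is to unregister and re-register through the routing table. A hedged sketch, assuming register_model keeps the parameters implied by the Models protocol above:

# Hedged sketch: register_model's exact signature is inferred, not shown in this diff.
async def replace_registered_model(
    table,  # a ModelsRoutingTable
    model_id: str,
    provider_model_id=None,
    provider_id=None,
    metadata=None,
):
    existing = await table.get_model(model_id)
    if existing is None:
        raise ValueError(f"Model {model_id} not found")

    # Drop the old registration, then register a merged copy.
    await table.unregister_model(model_id)
    return await table.register_model(
        model_id=model_id,
        provider_model_id=provider_model_id or existing.provider_resource_id,
        provider_id=provider_id or existing.provider_id,
        metadata=metadata or existing.metadata,
    )
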
@@ -357,31 +316,6 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         await self.register_object(memory_bank)
         return memory_bank

-    async def update_memory_bank(
-        self,
-        memory_bank_id: str,
-        params: BankParams,
-        provider_id: Optional[str] = None,
-        provider_memory_bank_id: Optional[str] = None,
-    ) -> MemoryBank:
-        existing_bank = await self.get_memory_bank(memory_bank_id)
-        if existing_bank is None:
-            raise ValueError(f"Memory bank {memory_bank_id} not found")
-
-        updated_bank = parse_obj_as(
-            MemoryBank,
-            {
-                "identifier": memory_bank_id,
-                "type": ResourceType.memory_bank.value,
-                "provider_id": provider_id or existing_bank.provider_id,
-                "provider_resource_id": provider_memory_bank_id
-                or existing_bank.provider_resource_id,
-                **params.model_dump(),
-            },
-        )
-        registered_bank = await self.update_object(updated_bank)
-        return registered_bank
-
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         existing_bank = await self.get_memory_bank(memory_bank_id)
         if existing_bank is None:
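
Memory banks follow the same pattern: the removed update_memory_bank re-parsed the bank with the caller's overrides and called update_object. The post-commit equivalent is unregister followed by register; a hedged sketch, with register_memory_bank's parameters assumed to mirror the removed update signature:

# Hedged sketch: register_memory_bank's exact signature is assumed, not shown here.
async def replace_memory_bank(
    table,  # a MemoryBanksRoutingTable
    memory_bank_id: str,
    params,  # BankParams
    provider_id=None,
    provider_memory_bank_id=None,
):
    existing = await table.get_memory_bank(memory_bank_id)
    if existing is None:
        raise ValueError(f"Memory bank {memory_bank_id} not found")

    # Drop the old bank, then register it again with the merged settings.
    await table.unregister_memory_bank(memory_bank_id)
    return await table.register_memory_bank(
        memory_bank_id=memory_bank_id,
        params=params,
        provider_id=provider_id or existing.provider_id,
        provider_memory_bank_id=provider_memory_bank_id
        or existing.provider_resource_id,
    )
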
@@ -45,8 +45,6 @@ class Api(Enum):
 class ModelsProtocolPrivate(Protocol):
     async def register_model(self, model: Model) -> None: ...

-    async def update_model(self, model: Model) -> None: ...
-
     async def unregister_model(self, model_id: str) -> None: ...

@@ -61,8 +59,6 @@ class MemoryBanksProtocolPrivate(Protocol):

     async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...

-    async def update_memory_bank(self, memory_bank: MemoryBank) -> None: ...
-

 class DatasetsProtocolPrivate(Protocol):
     async def register_dataset(self, dataset: Dataset) -> None: ...

@@ -107,6 +103,7 @@ class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...


+# TODO: this can now be inlined into RemoteProviderSpec
 @json_schema_type
 class AdapterSpec(BaseModel):
     adapter_type: str = Field(

@@ -179,12 +176,10 @@ class RemoteProviderConfig(BaseModel):

 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: Optional[AdapterSpec] = Field(
-        default=None,
+    adapter: AdapterSpec = Field(
         description="""
 If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here. If not specified, it indicates the remote
-as being "Llama Stack compatible"
+API responses, specify the adapter here.
 """,
     )

@@ -194,38 +189,21 @@ as being "Llama Stack compatible"

     @property
     def module(self) -> str:
-        if self.adapter:
-            return self.adapter.module
-        return "llama_stack.distribution.client"
+        return self.adapter.module

     @property
     def pip_packages(self) -> List[str]:
-        if self.adapter:
-            return self.adapter.pip_packages
-        return []
+        return self.adapter.pip_packages

     @property
     def provider_data_validator(self) -> Optional[str]:
-        if self.adapter:
-            return self.adapter.provider_data_validator
-        return None
+        return self.adapter.provider_data_validator


-def is_passthrough(spec: ProviderSpec) -> bool:
-    return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
-
-
 # Can avoid this by using Pydantic computed_field
-def remote_provider_spec(
-    api: Api, adapter: Optional[AdapterSpec] = None
-) -> RemoteProviderSpec:
-    config_class = (
-        adapter.config_class
-        if adapter and adapter.config_class
-        else "llama_stack.distribution.datatypes.RemoteProviderConfig"
-    )
-    provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
-
+def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
     return RemoteProviderSpec(
-        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
+        api=api,
+        provider_type=f"remote::{adapter.adapter_type}",
+        config_class=adapter.config_class,
+        adapter=adapter,
     )
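
With adapter now a required field, every remote provider spec must carry an AdapterSpec, and the old adapter-less "passthrough" remote is gone along with is_passthrough. A usage sketch of the new signature follows; the adapter_type, module, pip_packages, and config_class values are placeholders, and the import path is an assumption rather than something shown in this diff.

# Illustrative values only; no such provider exists in this diff.
from llama_stack.providers.datatypes import (  # assumed import path
    AdapterSpec,
    Api,
    remote_provider_spec,
)

spec = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_type="my_backend",
        pip_packages=["my-backend-client"],
        module="my_package.providers.my_backend",
        config_class="my_package.providers.my_backend.MyBackendConfig",
    ),
)
# The spec now derives everything from the adapter.
assert spec.provider_type == "remote::my_backend"
assert spec.module == "my_package.providers.my_backend"
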
@@ -71,9 +71,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )

-    async def update_model(self, model: Model) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

@@ -108,9 +108,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

         return VLLMSamplingParams(**kwargs)

-    async def update_model(self, model: Model) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

@@ -48,10 +48,9 @@ class FaissIndex(EmbeddingIndex):
         self.initialize()

     async def initialize(self) -> None:
-        if not self.kvstore or not self.bank_id:
+        if not self.kvstore:
             return

         # Load existing index data from kvstore
         index_key = f"faiss_index:v1::{self.bank_id}"
         stored_data = await self.kvstore.get(index_key)

@@ -63,7 +62,6 @@ class FaissIndex(EmbeddingIndex):
             for k, v in data["chunk_by_index"].items()
         }

         # Load FAISS index
         index_bytes = base64.b64decode(data["faiss_index"])
         self.index = faiss.deserialize_index(index_bytes)

@@ -71,17 +69,14 @@ class FaissIndex(EmbeddingIndex):
         if not self.kvstore or not self.bank_id:
             return

         # Serialize FAISS index
         index_bytes = faiss.serialize_index(self.index)

         # Prepare data for storage
         data = {
             "id_by_index": self.id_by_index,
             "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
             "faiss_index": base64.b64encode(index_bytes).decode(),
         }

         # Store in kvstore
         index_key = f"faiss_index:v1::{self.bank_id}"
         await self.kvstore.set(key=index_key, value=json.dumps(data))
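
The surviving persistence path serializes the FAISS index, base64-encodes it into a JSON document for the kvstore, and reverses the process in initialize(). A self-contained sketch of that round trip; the dimension, the sample vectors, and the frombuffer conversion are choices made for this example, not code from the repository.

import base64
import json

import faiss
import numpy as np

# Build a tiny index; 384 dimensions is an arbitrary choice for this sketch.
index = faiss.IndexFlatL2(384)
index.add(np.random.rand(10, 384).astype(np.float32))

# Serialize -> base64 -> JSON, as in _save_index above.
blob = json.dumps(
    {"faiss_index": base64.b64encode(faiss.serialize_index(index).tobytes()).decode()}
)

# JSON -> base64 -> deserialize, as in initialize() above.
data = json.loads(blob)
restored = faiss.deserialize_index(
    np.frombuffer(base64.b64decode(data["faiss_index"]), dtype=np.uint8)
)
assert restored.ntotal == index.ntotal
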
@@ -175,15 +170,6 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         del self.cache[memory_bank_id]
         await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")

-    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
-        # Not possible to update the index in place, so we delete and recreate
-        await self.cache[memory_bank.identifier].index.delete()
-
-        self.cache[memory_bank.identifier] = BankWithIndex(
-            bank=memory_bank,
-            index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
-        )
-
     async def insert_documents(
         self,
         bank_id: str,

@@ -93,9 +93,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def update_model(self, model: Model) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

@@ -69,9 +69,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def update_model(self, model: Model) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

@@ -58,9 +58,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def update_model(self, model: Model) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

@@ -141,10 +141,6 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]

-    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
-        await self.unregister_memory_bank(memory_bank.identifier)
-        await self.register_memory_bank(memory_bank)
-
     async def insert_documents(
         self,
         bank_id: str,

@@ -184,10 +184,6 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]

-    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
-        await self.unregister_memory_bank(memory_bank.identifier)
-        await self.register_memory_bank(memory_bank)
-
     async def list_memory_banks(self) -> List[MemoryBank]:
         banks = load_models(self.cursor, VectorMemoryBank)
         for bank in banks: