persist registered objects with distribution

This commit is contained in:
Dinesh Yeduguru 2024-11-01 08:31:17 -07:00 committed by Dinesh Yeduguru
parent ac93dd89cf
commit 8fe22230b8
9 changed files with 85 additions and 16 deletions

View file

@ -10,10 +10,10 @@ from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType
from pydantic import BaseModel, Field
@json_schema_type
class DatasetDef(BaseModel):
@ -24,7 +24,7 @@ class DatasetDef(BaseModel):
description="The schema definition for this dataset",
)
url: URL
metadata: Dict[str, Any] = Field(
metadata: Dict[str, str] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
)

View file

@ -17,7 +17,7 @@ class ModelDef(BaseModel):
llama_model: str = Field(
description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
)
metadata: Dict[str, Any] = Field(
metadata: Dict[str, str] = Field(
default_factory=dict,
description="Any additional metadata for this model",
)

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from typing import Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.common.type_system import JsonType, ParamType
from pydantic import BaseModel, Field
@json_schema_type
@ -36,7 +36,7 @@ class LLMAsJudgeContext(BaseModel):
class ScoringFnDef(BaseModel):
identifier: str
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
metadata: Dict[str, str] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)

View file

@ -26,7 +26,7 @@ class ShieldDef(BaseModel):
type: str = Field(
description="The type of shield this is; the value is one of the ShieldType enum"
)
params: Dict[str, Any] = Field(
params: Dict[str, str] = Field(
default_factory=dict,
description="Any additional parameters needed for this shield",
)

View file

@ -14,6 +14,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
def get_impl_api(p: Any) -> Api:
@ -170,8 +171,7 @@ class CommonRoutingTableImpl(RoutingTable):
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj)
# TODO: persist this to a store
await distribution_store.REGISTRY.register(obj)
class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -22,25 +22,28 @@ import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
SpanStatus,
start_trace,
)
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from .endpoints import get_all_api_endpoints
@ -278,6 +281,18 @@ def main(
config = StackRunConfig(**yaml.safe_load(fp))
app = FastAPI()
# instantiate kvstore for storing and retrieving distribution metadata
dist_kvstore = asyncio.run(
kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)
)
distribution_store.REGISTRY = distribution_store.DiskRegistry(dist_kvstore)
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
if Api.telemetry in impls:

View file

@ -0,0 +1,3 @@
from .registry import DiskRegistry, Registry
REGISTRY: Registry = None

View file

@ -0,0 +1,51 @@
import json
from typing import Protocol
from docs.openapi_generator.strong_typing.deserializer import create_deserializer
from docs.openapi_generator.strong_typing.serialization import object_to_json
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
from llama_stack.providers.utils.kvstore import KVStore
class Registry(Protocol):
async def get(self, identifier: str) -> [RoutableObjectWithProvider]: ...
async def register(self, obj: RoutableObjectWithProvider) -> None: ...
KEY_FORMAT = "distributions:registry:{}"
DESERIALIZER = create_deserializer(RoutableObjectWithProvider)
class DiskRegistry(Registry):
def __init__(self, kvstore: KVStore):
self.kvstore = kvstore
async def get(self, identifier: str) -> [RoutableObjectWithProvider]:
# Get JSON string from kvstore
json_str = await self.kvstore.get(KEY_FORMAT.format(identifier))
if not json_str:
return []
# Parse JSON string into list of objects
objects_data = json.loads(json_str)
return [DESERIALIZER.parse(obj_str) for obj_str in objects_data]
async def register(self, obj: RoutableObjectWithProvider) -> None:
# Get existing objects for this identifier
existing_objects = await self.get(obj.identifier)
# Add new object to list
existing_objects.append(obj)
# Convert all objects to JSON strings and store as JSON array
objects_json = [
object_to_json(existing_object) for existing_object in existing_objects
]
await self.kvstore.set(
KEY_FORMAT.format(obj.identifier), json.dumps(objects_json)
)

View file

@ -81,7 +81,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
identifier=llama_model,
llama_model=llama_model,
metadata={
"ollama_model": r["model"],
"ollama_model": str(r["model"]),
},
)
)