persist registered objects with distribution

This commit is contained in:
Dinesh Yeduguru 2024-11-01 08:31:17 -07:00 committed by Dinesh Yeduguru
parent ac93dd89cf
commit 8fe22230b8
9 changed files with 85 additions and 16 deletions

View file

@ -10,10 +10,10 @@ from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.common.type_system import ParamType
from pydantic import BaseModel, Field
@json_schema_type @json_schema_type
class DatasetDef(BaseModel): class DatasetDef(BaseModel):
@ -24,7 +24,7 @@ class DatasetDef(BaseModel):
description="The schema definition for this dataset", description="The schema definition for this dataset",
) )
url: URL url: URL
metadata: Dict[str, Any] = Field( metadata: Dict[str, str] = Field(
default_factory=dict, default_factory=dict,
description="Any additional metadata for this dataset", description="Any additional metadata for this dataset",
) )

View file

@ -17,7 +17,7 @@ class ModelDef(BaseModel):
llama_model: str = Field( llama_model: str = Field(
description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.", description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
) )
metadata: Dict[str, Any] = Field( metadata: Dict[str, str] = Field(
default_factory=dict, default_factory=dict,
description="Any additional metadata for this model", description="Any additional metadata for this model",
) )

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from typing import Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.common.type_system import JsonType, ParamType
from pydantic import BaseModel, Field
@json_schema_type @json_schema_type
@ -36,7 +36,7 @@ class LLMAsJudgeContext(BaseModel):
class ScoringFnDef(BaseModel): class ScoringFnDef(BaseModel):
identifier: str identifier: str
description: Optional[str] = None description: Optional[str] = None
metadata: Dict[str, Any] = Field( metadata: Dict[str, str] = Field(
default_factory=dict, default_factory=dict,
description="Any additional metadata for this definition", description="Any additional metadata for this definition",
) )

View file

@ -26,7 +26,7 @@ class ShieldDef(BaseModel):
type: str = Field( type: str = Field(
description="The type of shield this is; the value is one of the ShieldType enum" description="The type of shield this is; the value is one of the ShieldType enum"
) )
params: Dict[str, Any] = Field( params: Dict[str, str] = Field(
default_factory=dict, default_factory=dict,
description="Any additional parameters needed for this shield", description="Any additional parameters needed for this shield",
) )

View file

@ -14,6 +14,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
def get_impl_api(p: Any) -> Api: def get_impl_api(p: Any) -> Api:
@ -170,8 +171,7 @@ class CommonRoutingTableImpl(RoutingTable):
if obj.identifier not in self.registry: if obj.identifier not in self.registry:
self.registry[obj.identifier] = [] self.registry[obj.identifier] = []
self.registry[obj.identifier].append(obj) self.registry[obj.identifier].append(obj)
await distribution_store.REGISTRY.register(obj)
# TODO: persist this to a store
class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -22,25 +22,28 @@ import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
get_provider_registry, get_provider_registry,
) )
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
end_trace, end_trace,
setup_logger, setup_logger,
SpanStatus, SpanStatus,
start_trace, start_trace,
) )
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls from llama_stack.distribution.resolver import resolve_impls
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from .endpoints import get_all_api_endpoints from .endpoints import get_all_api_endpoints
@ -278,6 +281,18 @@ def main(
config = StackRunConfig(**yaml.safe_load(fp)) config = StackRunConfig(**yaml.safe_load(fp))
app = FastAPI() app = FastAPI()
# instantiate kvstore for storing and retrieving distribution metadata
dist_kvstore = asyncio.run(
kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)
)
distribution_store.REGISTRY = distribution_store.DiskRegistry(dist_kvstore)
impls = asyncio.run(resolve_impls(config, get_provider_registry())) impls = asyncio.run(resolve_impls(config, get_provider_registry()))
if Api.telemetry in impls: if Api.telemetry in impls:

View file

@ -0,0 +1,3 @@
from typing import Optional

from .registry import DiskRegistry, Registry

# Process-wide registry handle for the running distribution.
# It starts as None and is assigned a DiskRegistry during server startup,
# before any routing table calls REGISTRY.register() — hence Optional.
# (The original annotation `REGISTRY: Registry = None` was a type error:
# None is not a valid Registry.)
REGISTRY: Optional[Registry] = None

View file

@ -0,0 +1,51 @@
import json
from typing import List, Protocol

from docs.openapi_generator.strong_typing.deserializer import create_deserializer
from docs.openapi_generator.strong_typing.serialization import object_to_json

from llama_stack.distribution.datatypes import RoutableObjectWithProvider
from llama_stack.providers.utils.kvstore import KVStore
class Registry(Protocol):
    """Structural interface for persisting routable objects by identifier.

    Implementations store zero or more RoutableObjectWithProvider entries
    per identifier and return them all on lookup.
    """

    # Fix: the original return annotation `[RoutableObjectWithProvider]` is a
    # list *literal* containing the class, not a type — it must be
    # List[RoutableObjectWithProvider].
    async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: ...

    async def register(self, obj: RoutableObjectWithProvider) -> None: ...
# KVStore key template: one key per registered identifier, namespaced so
# registry entries don't collide with other distribution metadata.
KEY_FORMAT = "distributions:registry:{}"
# Shared deserializer for turning stored JSON back into
# RoutableObjectWithProvider instances (built once at import time).
DESERIALIZER = create_deserializer(RoutableObjectWithProvider)
class DiskRegistry(Registry):
    """Registry implementation backed by a KVStore.

    Each identifier maps to a JSON array of serialized objects stored under
    the key ``distributions:registry:<identifier>``.
    """

    def __init__(self, kvstore: KVStore):
        self.kvstore = kvstore

    # Fix: original annotation `[RoutableObjectWithProvider]` was a list
    # literal, not a valid type — corrected to List[RoutableObjectWithProvider].
    async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
        """Return every object registered under `identifier`, or [] if none."""
        json_str = await self.kvstore.get(KEY_FORMAT.format(identifier))
        if not json_str:
            # Missing or empty key: nothing registered yet.
            return []

        # The value is a JSON array of per-object JSON strings; parse each
        # element back into a RoutableObjectWithProvider.
        objects_data = json.loads(json_str)
        return [DESERIALIZER.parse(obj_str) for obj_str in objects_data]

    async def register(self, obj: RoutableObjectWithProvider) -> None:
        """Append `obj` to the list stored under its identifier.

        NOTE(review): there is no de-duplication — registering the same
        object twice stores two copies; confirm callers guard against this.
        """
        existing_objects = await self.get(obj.identifier)
        existing_objects.append(obj)

        # Serialize every object and persist the whole list as one JSON array.
        objects_json = [
            object_to_json(existing_object) for existing_object in existing_objects
        ]
        await self.kvstore.set(
            KEY_FORMAT.format(obj.identifier), json.dumps(objects_json)
        )

View file

@ -81,7 +81,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
identifier=llama_model, identifier=llama_model,
llama_model=llama_model, llama_model=llama_model,
metadata={ metadata={
"ollama_model": r["model"], "ollama_model": str(r["model"]),
}, },
) )
) )