mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
persist registered objects with distribution
This commit is contained in:
parent
ac93dd89cf
commit
8fe22230b8
9 changed files with 85 additions and 16 deletions
|
@ -10,10 +10,10 @@ from llama_models.llama3.api.datatypes import URL
|
|||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatasetDef(BaseModel):
|
||||
|
@ -24,7 +24,7 @@ class DatasetDef(BaseModel):
|
|||
description="The schema definition for this dataset",
|
||||
)
|
||||
url: URL
|
||||
metadata: Dict[str, Any] = Field(
|
||||
metadata: Dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this dataset",
|
||||
)
|
||||
|
|
|
@ -17,7 +17,7 @@ class ModelDef(BaseModel):
|
|||
llama_model: str = Field(
|
||||
description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
|
||||
)
|
||||
metadata: Dict[str, Any] = Field(
|
||||
metadata: Dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this model",
|
||||
)
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
from typing import Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.common.type_system import JsonType, ParamType
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -36,7 +36,7 @@ class LLMAsJudgeContext(BaseModel):
|
|||
class ScoringFnDef(BaseModel):
|
||||
identifier: str
|
||||
description: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(
|
||||
metadata: Dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this definition",
|
||||
)
|
||||
|
|
|
@ -26,7 +26,7 @@ class ShieldDef(BaseModel):
|
|||
type: str = Field(
|
||||
description="The type of shield this is; the value is one of the ShieldType enum"
|
||||
)
|
||||
params: Dict[str, Any] = Field(
|
||||
params: Dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional parameters needed for this shield",
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403
|
|||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
import llama_stack.distribution.store as distribution_store
|
||||
|
||||
|
||||
def get_impl_api(p: Any) -> Api:
|
||||
|
@ -170,8 +171,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if obj.identifier not in self.registry:
|
||||
self.registry[obj.identifier] = []
|
||||
self.registry[obj.identifier].append(obj)
|
||||
|
||||
# TODO: persist this to a store
|
||||
await distribution_store.REGISTRY.register(obj)
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
|
|
@ -22,25 +22,28 @@ import yaml
|
|||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
import llama_stack.distribution.store as distribution_store
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
@ -278,6 +281,18 @@ def main(
|
|||
config = StackRunConfig(**yaml.safe_load(fp))
|
||||
|
||||
app = FastAPI()
|
||||
# instantiate kvstore for storing and retrieving distribution metadata
|
||||
dist_kvstore = asyncio.run(
|
||||
kvstore_impl(
|
||||
SqliteKVStoreConfig(
|
||||
db_path=(
|
||||
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
|
||||
).as_posix()
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
distribution_store.REGISTRY = distribution_store.DiskRegistry(dist_kvstore)
|
||||
|
||||
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
|
||||
if Api.telemetry in impls:
|
||||
|
|
3
llama_stack/distribution/store/__init__.py
Normal file
3
llama_stack/distribution/store/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from typing import Optional

from .registry import DiskRegistry, Registry

# Process-wide registry handle. It starts out as None and is expected to be
# replaced with a concrete Registry (such as a DiskRegistry) by startup code
# before anything calls REGISTRY.get() / REGISTRY.register().
REGISTRY: Optional[Registry] = None
|
51
llama_stack/distribution/store/registry.py
Normal file
51
llama_stack/distribution/store/registry.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import json
|
||||
|
||||
from typing import List, Protocol
|
||||
|
||||
from docs.openapi_generator.strong_typing.deserializer import create_deserializer
|
||||
|
||||
from docs.openapi_generator.strong_typing.serialization import object_to_json
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
|
||||
class Registry(Protocol):
    """Structural interface for a store of registered routable objects.

    Several objects may be registered under the same identifier, which is
    why ``get`` returns a list rather than a single object.
    """

    async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
        """Return the objects registered under ``identifier``."""
        ...

    async def register(self, obj: RoutableObjectWithProvider) -> None:
        """Record ``obj`` under its own ``obj.identifier``."""
        ...
|
||||
|
||||
|
||||
# Key under which all objects sharing one identifier are stored in the kvstore.
KEY_FORMAT = "distributions:registry:{}"
# Built once at import time and reused for every stored entry.
DESERIALIZER = create_deserializer(RoutableObjectWithProvider)


class DiskRegistry(Registry):
    """Registry backed by a ``KVStore``.

    Every identifier maps to a single kvstore key (``KEY_FORMAT`` applied to
    the identifier). The value is one JSON array whose elements are the
    ``object_to_json`` form of each object registered under that identifier.
    """

    def __init__(self, kvstore: KVStore):
        """Keep a reference to the ``kvstore`` used for all reads and writes."""
        self.kvstore = kvstore

    async def _load_entries(self, identifier: str) -> list:
        """Return the raw JSON entries stored for ``identifier`` ([] if none)."""
        payload = await self.kvstore.get(KEY_FORMAT.format(identifier))
        if not payload:
            return []
        return json.loads(payload)

    async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
        """Return every object stored under ``identifier``.

        Returns an empty list when the kvstore holds nothing (or an empty
        value) for the identifier.
        """
        entries = await self._load_entries(identifier)
        return [DESERIALIZER.parse(entry) for entry in entries]

    async def register(self, obj: RoutableObjectWithProvider) -> None:
        """Persist ``obj`` under ``obj.identifier``.

        Registering an object whose serialized form is already stored is a
        no-op, so repeating the same registration (for instance each time the
        process starts) does not grow the stored array with duplicates.
        """
        # NOTE(review): this is a read-modify-write with no locking; concurrent
        # registrations for the same identifier could overwrite each other.
        # Confirm whether callers serialize registrations.
        entries = await self._load_entries(obj.identifier)

        # Round-trip through JSON so the new entry has exactly the shape the
        # stored entries have after json.loads (e.g. lists rather than
        # tuples); otherwise the duplicate check could miss an identical
        # object.
        new_entry = json.loads(json.dumps(object_to_json(obj)))
        if new_entry in entries:
            return

        entries.append(new_entry)
        await self.kvstore.set(
            KEY_FORMAT.format(obj.identifier), json.dumps(entries)
        )
|
|
@ -81,7 +81,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
identifier=llama_model,
|
||||
llama_model=llama_model,
|
||||
metadata={
|
||||
"ollama_model": r["model"],
|
||||
"ollama_model": str(r["model"]),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue