mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
persist registered objects with distribution
This commit is contained in:
parent
ac93dd89cf
commit
8fe22230b8
9 changed files with 85 additions and 16 deletions
|
@ -10,10 +10,10 @@ from llama_models.llama3.api.datatypes import URL
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class DatasetDef(BaseModel):
|
class DatasetDef(BaseModel):
|
||||||
|
@ -24,7 +24,7 @@ class DatasetDef(BaseModel):
|
||||||
description="The schema definition for this dataset",
|
description="The schema definition for this dataset",
|
||||||
)
|
)
|
||||||
url: URL
|
url: URL
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: Dict[str, str] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Any additional metadata for this dataset",
|
description="Any additional metadata for this dataset",
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,7 +17,7 @@ class ModelDef(BaseModel):
|
||||||
llama_model: str = Field(
|
llama_model: str = Field(
|
||||||
description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
|
description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
|
||||||
)
|
)
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: Dict[str, str] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Any additional metadata for this model",
|
description="Any additional metadata for this model",
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,12 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Dict, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import JsonType, ParamType
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -36,7 +36,7 @@ class LLMAsJudgeContext(BaseModel):
|
||||||
class ScoringFnDef(BaseModel):
|
class ScoringFnDef(BaseModel):
|
||||||
identifier: str
|
identifier: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: Dict[str, str] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Any additional metadata for this definition",
|
description="Any additional metadata for this definition",
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,7 +26,7 @@ class ShieldDef(BaseModel):
|
||||||
type: str = Field(
|
type: str = Field(
|
||||||
description="The type of shield this is; the value is one of the ShieldType enum"
|
description="The type of shield this is; the value is one of the ShieldType enum"
|
||||||
)
|
)
|
||||||
params: Dict[str, Any] = Field(
|
params: Dict[str, str] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Any additional parameters needed for this shield",
|
description="Any additional parameters needed for this shield",
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,6 +14,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403
|
||||||
from llama_stack.apis.datasets import * # noqa: F403
|
from llama_stack.apis.datasets import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
import llama_stack.distribution.store as distribution_store
|
||||||
|
|
||||||
|
|
||||||
def get_impl_api(p: Any) -> Api:
|
def get_impl_api(p: Any) -> Api:
|
||||||
|
@ -170,8 +171,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
if obj.identifier not in self.registry:
|
if obj.identifier not in self.registry:
|
||||||
self.registry[obj.identifier] = []
|
self.registry[obj.identifier] = []
|
||||||
self.registry[obj.identifier].append(obj)
|
self.registry[obj.identifier].append(obj)
|
||||||
|
await distribution_store.REGISTRY.register(obj)
|
||||||
# TODO: persist this to a store
|
|
||||||
|
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
|
@ -22,25 +22,28 @@ import yaml
|
||||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from pydantic import BaseModel, ValidationError
|
|
||||||
from termcolor import cprint
|
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
builtin_automatically_routed_apis,
|
builtin_automatically_routed_apis,
|
||||||
get_provider_registry,
|
get_provider_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
|
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
end_trace,
|
end_trace,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
SpanStatus,
|
SpanStatus,
|
||||||
start_trace,
|
start_trace,
|
||||||
)
|
)
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
from termcolor import cprint
|
||||||
|
from typing_extensions import Annotated
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
import llama_stack.distribution.store as distribution_store
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls
|
from llama_stack.distribution.resolver import resolve_impls
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||||
|
|
||||||
from .endpoints import get_all_api_endpoints
|
from .endpoints import get_all_api_endpoints
|
||||||
|
|
||||||
|
@ -278,6 +281,18 @@ def main(
|
||||||
config = StackRunConfig(**yaml.safe_load(fp))
|
config = StackRunConfig(**yaml.safe_load(fp))
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
# instantiate kvstore for storing and retrieving distribution metadata
|
||||||
|
dist_kvstore = asyncio.run(
|
||||||
|
kvstore_impl(
|
||||||
|
SqliteKVStoreConfig(
|
||||||
|
db_path=(
|
||||||
|
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
|
||||||
|
).as_posix()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
distribution_store.REGISTRY = distribution_store.DiskRegistry(dist_kvstore)
|
||||||
|
|
||||||
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
|
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
|
|
3
llama_stack/distribution/store/__init__.py
Normal file
3
llama_stack/distribution/store/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from .registry import DiskRegistry, Registry
|
||||||
|
|
||||||
|
REGISTRY: Registry = None
|
51
llama_stack/distribution/store/registry.py
Normal file
51
llama_stack/distribution/store/registry.py
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from docs.openapi_generator.strong_typing.deserializer import create_deserializer
|
||||||
|
|
||||||
|
from docs.openapi_generator.strong_typing.serialization import object_to_json
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
|
|
||||||
|
class Registry(Protocol):
|
||||||
|
async def get(self, identifier: str) -> [RoutableObjectWithProvider]: ...
|
||||||
|
async def register(self, obj: RoutableObjectWithProvider) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
KEY_FORMAT = "distributions:registry:{}"
|
||||||
|
DESERIALIZER = create_deserializer(RoutableObjectWithProvider)
|
||||||
|
|
||||||
|
|
||||||
|
class DiskRegistry(Registry):
|
||||||
|
def __init__(self, kvstore: KVStore):
|
||||||
|
self.kvstore = kvstore
|
||||||
|
|
||||||
|
async def get(self, identifier: str) -> [RoutableObjectWithProvider]:
|
||||||
|
# Get JSON string from kvstore
|
||||||
|
json_str = await self.kvstore.get(KEY_FORMAT.format(identifier))
|
||||||
|
if not json_str:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Parse JSON string into list of objects
|
||||||
|
objects_data = json.loads(json_str)
|
||||||
|
|
||||||
|
return [DESERIALIZER.parse(obj_str) for obj_str in objects_data]
|
||||||
|
|
||||||
|
async def register(self, obj: RoutableObjectWithProvider) -> None:
|
||||||
|
# Get existing objects for this identifier
|
||||||
|
existing_objects = await self.get(obj.identifier)
|
||||||
|
|
||||||
|
# Add new object to list
|
||||||
|
existing_objects.append(obj)
|
||||||
|
|
||||||
|
# Convert all objects to JSON strings and store as JSON array
|
||||||
|
objects_json = [
|
||||||
|
object_to_json(existing_object) for existing_object in existing_objects
|
||||||
|
]
|
||||||
|
await self.kvstore.set(
|
||||||
|
KEY_FORMAT.format(obj.identifier), json.dumps(objects_json)
|
||||||
|
)
|
|
@ -81,7 +81,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
identifier=llama_model,
|
identifier=llama_model,
|
||||||
llama_model=llama_model,
|
llama_model=llama_model,
|
||||||
metadata={
|
metadata={
|
||||||
"ollama_model": r["model"],
|
"ollama_model": str(r["model"]),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue