From a3064ca6fcf1603430535453b9caad13f390b88b Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 1 Nov 2024 14:32:50 -0700 Subject: [PATCH] add back Any --- llama_stack/apis/datasets/datasets.py | 8 ++++---- llama_stack/apis/models/models.py | 4 ++-- llama_stack/apis/scoring_functions/scoring_functions.py | 6 +++--- llama_stack/apis/shields/shields.py | 4 ++-- llama_stack/distribution/store/registry.py | 4 ++++ 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 4c5540b72..f5efc7cbc 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -4,16 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict, List, Literal, Optional, Protocol +from typing import Any, Dict, List, Literal, Optional, Protocol from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel, Field - from llama_stack.apis.common.type_system import ParamType +from pydantic import BaseModel, Field + @json_schema_type class DatasetDef(BaseModel): @@ -24,7 +24,7 @@ class DatasetDef(BaseModel): description="The schema definition for this dataset", ) url: URL - metadata: Dict[str, str] = Field( + metadata: Dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this dataset", ) diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 5c75a1ca0..ffb3b022e 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict, List, Literal, Optional, Protocol, runtime_checkable +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field @@ -17,7 +17,7 @@ class ModelDef(BaseModel): llama_model: str = Field( description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.", ) - metadata: Dict[str, str] = Field( + metadata: Dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this model", ) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index bf612bda0..d5c65e738 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict, List, Literal, Optional, Protocol, runtime_checkable +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel, Field from llama_stack.apis.common.type_system import ParamType +from pydantic import BaseModel, Field @json_schema_type @@ -36,7 +36,7 @@ class LLMAsJudgeContext(BaseModel): class ScoringFnDef(BaseModel): identifier: str description: Optional[str] = None - metadata: Dict[str, str] = Field( + metadata: Dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this definition", ) diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 826e7c30e..0d1177f5a 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Dict, List, Literal, Optional, Protocol, runtime_checkable +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field @@ -26,7 +26,7 @@ class ShieldDef(BaseModel): type: str = Field( description="The type of shield this is; the value is one of the ShieldType enum" ) - params: Dict[str, str] = Field( + params: Dict[str, Any] = Field( default_factory=dict, description="Any additional parameters needed for this shield", ) diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 95a393aa5..72ce20245 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -47,6 +47,10 @@ class DiskRegistry(Registry): # TODO: make it thread safe using CAS async def register(self, obj: RoutableObjectWithProvider) -> None: existing_objects = await self.get(obj.identifier) + # dont register if the object's providerid already exists + for eobj in existing_objects: + if eobj.provider_id == obj.provider_id: + return existing_objects.append(obj)