rebase on top of registry

This commit is contained in:
Xi Yan 2024-10-08 23:41:03 -07:00
commit 6abef716dd
107 changed files with 4813 additions and 3587 deletions

View file

@ -261,7 +261,7 @@ class Session(BaseModel):
turns: List[Turn]
started_at: datetime
memory_bank: Optional[MemoryBank] = None
memory_bank: Optional[MemoryBankDef] = None
class AgentConfigCommon(BaseModel):
@ -411,8 +411,10 @@ class Agents(Protocol):
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create")
async def create_agent_turn(
def create_agent_turn(
self,
agent_id: str,
session_id: str,

View file

@ -7,7 +7,7 @@
import asyncio
import json
import os
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
import fire
import httpx
@ -67,9 +67,17 @@ class AgentsClient(Agents):
response.raise_for_status()
return AgentSessionCreateResponse(**response.json())
async def create_agent_turn(
def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
return self._nonstream_agent_turn(request)
async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
@ -93,6 +101,9 @@ class AgentsClient(Agents):
print(data)
print(f"Error with parsing or validation: {e}")
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
raise NotImplementedError("Non-streaming not implemented yet")
async def _run_agent(
api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
@ -132,8 +143,7 @@ async def _run_agent(
log.print()
async def run_llama_3_1(host: str, port: int):
model = "Llama3.1-8B-Instruct"
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [
@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
async def run_llama_3_2_rag(host: str, port: int):
model = "Llama3.2-3B-Instruct"
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
urls = [
@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
)
async def run_llama_3_2(host: str, port: int):
model = "Llama3.2-3B-Instruct"
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
# zero shot tools for llama3.2 text models
@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
)
def main(host: str, port: int, run_type: str):
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
assert run_type in [
"tools_llama_3_1",
"tools_llama_3_2",
@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
"tools_llama_3_2": run_llama_3_2,
"rag_llama_3_2": run_llama_3_2_rag,
}
asyncio.run(fn[run_type](host, port))
args = [host, port]
if model is not None:
args.append(model)
asyncio.run(fn[run_type](*args))
if __name__ == "__main__":

View file

@ -42,10 +42,10 @@ class InferenceClient(Inference):
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -66,6 +66,29 @@ class InferenceClient(Inference):
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
)
response.raise_for_status()
j = response.json()
return ChatCompletionResponse(**j)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
@ -77,7 +100,8 @@ class InferenceClient(Inference):
if response.status_code != 200:
content = await response.aread()
cprint(
f"Error: HTTP {response.status_code} {content.decode()}", "red"
f"Error: HTTP {response.status_code} {content.decode()}",
"red",
)
return
@ -85,40 +109,59 @@ class InferenceClient(Inference):
if line.startswith("data:"):
data = line[len("data: ") :]
try:
if request.stream:
if "error" in data:
cprint(data, "red")
continue
if "error" in data:
cprint(data, "red")
continue
yield ChatCompletionResponseStreamChunk(
**json.loads(data)
)
else:
yield ChatCompletionResponse(**json.loads(data))
yield ChatCompletionResponseStreamChunk(**json.loads(data))
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
async def run_main(host: str, port: int, stream: bool):
async def run_main(
host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
):
client = InferenceClient(f"http://{host}:{port}")
if not model:
model = "Llama3.1-8B-Instruct"
message = UserMessage(
content="hello world, write me a 2 sentence poem about the moon"
)
cprint(f"User>{message.content}", "green")
if logprobs:
logprobs_config = LogProbConfig(
top_k=1,
)
else:
logprobs_config = None
iterator = client.chat_completion(
model="Llama3.1-8B-Instruct",
model=model,
messages=[message],
stream=stream,
logprobs=logprobs_config,
)
async for log in EventLogger().log(iterator):
log.print()
if logprobs:
async for chunk in iterator:
cprint(f"Response: {chunk}", "red")
else:
async for log in EventLogger().log(iterator):
log.print()
async def run_mm_main(host: str, port: int, stream: bool, path: str):
async def run_mm_main(
host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
):
client = InferenceClient(f"http://{host}:{port}")
if not model:
model = "Llama3.2-11B-Vision-Instruct"
message = UserMessage(
content=[
ImageMedia(image=URL(uri=f"file://{path}")),
@ -127,7 +170,7 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
model="Llama3.2-11B-Vision-Instruct",
model=model,
messages=[message],
stream=stream,
)
@ -135,11 +178,19 @@ async def run_mm_main(host: str, port: int, stream: bool, path: str):
log.print()
def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
def main(
host: str,
port: int,
stream: bool = True,
mm: bool = False,
logprobs: bool = False,
file: Optional[str] = None,
model: Optional[str] = None,
):
if mm:
asyncio.run(run_mm_main(host, port, stream, file))
asyncio.run(run_mm_main(host, port, stream, file, model))
else:
asyncio.run(run_main(host, port, stream))
asyncio.run(run_main(host, port, stream, model, logprobs))
if __name__ == "__main__":

View file

@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
class LogProbConfig(BaseModel):
@ -172,9 +173,17 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class ModelStore(Protocol):
def get_model(self, identifier: str) -> ModelDef: ...
class Inference(Protocol):
model_store: ModelStore
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/completion")
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -183,8 +192,10 @@ class Inference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -203,3 +214,6 @@ class Inference(Protocol):
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...
@webmethod(route="/inference/register_model")
async def register_model(self, model: ModelDef) -> None: ...

View file

@ -12,15 +12,15 @@ from pydantic import BaseModel
@json_schema_type
class ProviderInfo(BaseModel):
provider_id: str
provider_type: str
description: str
@json_schema_type
class RouteInfo(BaseModel):
route: str
method: str
providers: List[str]
provider_types: List[str]
@json_schema_type

View file

@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional
import fire
import httpx
from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks.client import MemoryBanksClient
from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@ -35,44 +35,16 @@ class MemoryClient(Memory):
async def shutdown(self) -> None:
pass
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
async with httpx.AsyncClient() as client:
r = await client.get(
f"{self.base_url}/memory/get",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory/create",
response = await client.post(
f"{self.base_url}/memory/register_memory_bank",
json={
"name": name,
"config": config.dict(),
"url": url,
"memory_bank": json.loads(memory_bank.json()),
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
response.raise_for_status()
async def insert_documents(
self,
@ -114,22 +86,20 @@ class MemoryClient(Memory):
async def run_main(host: str, port: int, stream: bool):
client = MemoryClient(f"http://{host}:{port}")
banks_client = MemoryBanksClient(f"http://{host}:{port}")
# create a memory bank
bank = await client.create_memory_bank(
name="test_bank",
config=VectorMemoryBankConfig(
bank_id="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
bank = VectorMemoryBankDef(
identifier="test_bank",
provider_id="",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
cprint(json.dumps(bank.dict(), indent=4), "green")
await client.register_memory_bank(bank)
retrieved_bank = await client.get_memory_bank(bank.bank_id)
retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
assert retrieved_bank is not None
assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2"
urls = [
"memory_optimizations.rst",
@ -162,13 +132,13 @@ async def run_main(host: str, port: int, stream: bool):
# insert some documents
await client.insert_documents(
bank_id=bank.bank_id,
bank_id=bank.identifier,
documents=documents,
)
# query the documents
response = await client.query_documents(
bank_id=bank.bank_id,
bank_id=bank.identifier,
query=[
"How do I use Lora?",
],
@ -178,7 +148,7 @@ async def run_main(host: str, port: int, stream: bool):
print(f"Chunk:\n========\n{chunk}\n========\n")
response = await client.query_documents(
bank_id=bank.bank_id,
bank_id=bank.identifier,
query=[
"Tell me more about llama3 and torchtune",
],

View file

@ -13,9 +13,9 @@ from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
@json_schema_type
@ -26,44 +26,6 @@ class MemoryBankDocument(BaseModel):
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class MemoryBankType(Enum):
vector = "vector"
keyvalue = "keyvalue"
keyword = "keyword"
graph = "graph"
class VectorMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
class KeyValueMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
class KeywordMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class GraphMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class Chunk(BaseModel):
content: InterleavedTextMedia
token_count: int
@ -76,45 +38,12 @@ class QueryDocumentsResponse(BaseModel):
scores: List[float]
@json_schema_type
class QueryAPI(Protocol):
@webmethod(route="/query_documents")
def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@json_schema_type
class MemoryBank(BaseModel):
bank_id: str
name: str
config: MemoryBankConfig
# if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
url: Optional[URL] = None
class MemoryBankStore(Protocol):
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
class Memory(Protocol):
@webmethod(route="/memory/create")
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank: ...
@webmethod(route="/memory/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory/get", method="GET")
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory/drop", method="DELETE")
async def drop_memory_bank(
self,
bank_id: str,
) -> str: ...
memory_bank_store: MemoryBankStore
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@ -154,3 +83,6 @@ class Memory(Protocol):
bank_id: str,
document_ids: List[str],
) -> None: ...
@webmethod(route="/memory/register_memory_bank")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...

View file

@ -6,7 +6,7 @@
import asyncio
from typing import List, Optional
from typing import Any, Dict, List, Optional
import fire
import httpx
@ -15,6 +15,25 @@ from termcolor import cprint
from .memory_banks import * # noqa: F403
def deserialize_memory_bank_def(j: Optional[Dict[str, Any]]) -> MemoryBankDef:
if j is None:
return None
if "type" not in j:
raise ValueError("Memory bank type not specified")
type = j["type"]
if type == MemoryBankType.vector.value:
return VectorMemoryBankDef(**j)
elif type == MemoryBankType.keyvalue.value:
return KeyValueMemoryBankDef(**j)
elif type == MemoryBankType.keyword.value:
return KeywordMemoryBankDef(**j)
elif type == MemoryBankType.graph.value:
return GraphMemoryBankDef(**j)
else:
raise ValueError(f"Unknown memory bank type: {type}")
class MemoryBanksClient(MemoryBanks):
def __init__(self, base_url: str):
self.base_url = base_url
@ -25,37 +44,36 @@ class MemoryBanksClient(MemoryBanks):
async def shutdown(self) -> None:
pass
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
async def list_memory_banks(self) -> List[MemoryBankDef]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return [MemoryBankSpec(**x) for x in response.json()]
return [deserialize_memory_bank_def(x) for x in response.json()]
async def get_serving_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]:
async def get_memory_bank(
self,
identifier: str,
) -> Optional[MemoryBankDef]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_type": bank_type.value,
"identifier": identifier,
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
j = response.json()
if j is None:
return None
return MemoryBankSpec(**j)
return deserialize_memory_bank_def(j)
async def run_main(host: str, port: int, stream: bool):
client = MemoryBanksClient(f"http://{host}:{port}")
response = await client.list_available_memory_banks()
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")

View file

@ -4,29 +4,67 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from enum import Enum
from typing import List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig
from typing_extensions import Annotated
@json_schema_type
class MemoryBankSpec(BaseModel):
bank_type: MemoryBankType
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_type, and corresponding config. ",
)
class MemoryBankType(Enum):
vector = "vector"
keyvalue = "keyvalue"
keyword = "keyword"
graph = "graph"
class CommonDef(BaseModel):
identifier: str
provider_id: Optional[str] = None
@json_schema_type
class VectorMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
@json_schema_type
class KeyValueMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
@json_schema_type
class KeywordMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
@json_schema_type
class GraphMemoryBankDef(CommonDef):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankDef = Annotated[
Union[
VectorMemoryBankDef,
KeyValueMemoryBankDef,
KeywordMemoryBankDef,
GraphMemoryBankDef,
],
Field(discriminator="type"),
]
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET")
async def list_available_memory_banks(self) -> List[MemoryBankSpec]: ...
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
@webmethod(route="/memory_banks/get", method="GET")
async def get_serving_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]: ...
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
@webmethod(route="/memory_banks/register", method="POST")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...

View file

@ -56,7 +56,7 @@ async def run_main(host: str, port: int, stream: bool):
response = await client.list_models()
cprint(f"list_models response={response}", "green")
response = await client.get_model("Meta-Llama3.1-8B-Instruct")
response = await client.get_model("Llama3.1-8B-Instruct")
cprint(f"get_model response={response}", "blue")
response = await client.get_model("Llama-Guard-3-1B")

View file

@ -6,27 +6,32 @@
from typing import List, Optional, Protocol
from llama_models.llama3.api.datatypes import Model
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type
class ModelServingSpec(BaseModel):
llama_model: Model = Field(
description="All metadatas associated with llama model (defined in llama_models.models.sku_list).",
class ModelDef(BaseModel):
identifier: str = Field(
description="A unique identifier for the model type",
)
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_type, and corresponding config. ",
llama_model: str = Field(
description="Pointer to the core Llama family model",
)
provider_id: Optional[str] = Field(
default=None, description="The provider instance which serves this model"
)
# For now, we are only supporting core llama models but as soon as finetuned
# and other custom models (for example various quantizations) are allowed, there
# will be more metadata fields here
class Models(Protocol):
@webmethod(route="/models/list", method="GET")
async def list_models(self) -> List[ModelServingSpec]: ...
async def list_models(self) -> List[ModelDef]: ...
@webmethod(route="/models/get", method="GET")
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
@webmethod(route="/models/register", method="POST")
async def register_model(self, model: ModelDef) -> None: ...

View file

@ -96,12 +96,6 @@ async def run_main(host: str, port: int, image_path: str = None):
)
print(response)
response = await client.run_shield(
shield_type="injection_shield",
messages=[message],
)
print(response)
def main(host: str, port: int, image: str = None):
asyncio.run(run_main(host, port, image))

View file

@ -11,6 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
@json_schema_type
@ -37,8 +38,17 @@ class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None
class ShieldStore(Protocol):
def get_shield(self, identifier: str) -> ShieldDef: ...
class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run_shield")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ...
@webmethod(route="/safety/register_shield")
async def register_shield(self, shield: ShieldDef) -> None: ...

View file

@ -4,25 +4,43 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type
class ShieldSpec(BaseModel):
shield_type: str
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_type, and corresponding config. ",
class ShieldType(Enum):
generic_content_shield = "generic_content_shield"
llama_guard = "llama_guard"
code_scanner = "code_scanner"
prompt_guard = "prompt_guard"
class ShieldDef(BaseModel):
identifier: str = Field(
description="A unique identifier for the shield type",
)
type: str = Field(
description="The type of shield this is; the value is one of the ShieldType enum"
)
provider_id: Optional[str] = Field(
default=None, description="The provider instance which serves this shield"
)
params: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional parameters needed for this shield",
)
class Shields(Protocol):
@webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[ShieldSpec]: ...
async def list_shields(self) -> List[ShieldDef]: ...
@webmethod(route="/shields/get", method="GET")
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ...
@webmethod(route="/shields/register", method="POST")
async def register_shield(self, shield: ShieldDef) -> None: ...