formatting

Dalton Flanagan 2024-08-14 14:22:25 -04:00
parent 94dfa293a6
commit b6ccaf1778
33 changed files with 110 additions and 97 deletions


@@ -63,9 +63,9 @@ class ShieldCallStep(StepCommon):
@json_schema_type
class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval.value] = (
step_type: Literal[
StepType.memory_retrieval.value
)
] = StepType.memory_retrieval.value
memory_bank_ids: List[str]
documents: List[MemoryBankDocument]
scores: List[float]
@@ -140,9 +140,9 @@ class AgenticSystemTurnResponseEventType(Enum):
@json_schema_type
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
event_type: Literal[
AgenticSystemTurnResponseEventType.step_start.value
)
] = AgenticSystemTurnResponseEventType.step_start.value
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@@ -150,9 +150,9 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
@json_schema_type
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
event_type: Literal[
AgenticSystemTurnResponseEventType.step_complete.value
)
] = AgenticSystemTurnResponseEventType.step_complete.value
step_type: StepType
step_details: Step
@@ -161,9 +161,9 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
event_type: Literal[
AgenticSystemTurnResponseEventType.step_progress.value
)
] = AgenticSystemTurnResponseEventType.step_progress.value
step_type: StepType
step_id: str
@@ -174,17 +174,17 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
@json_schema_type
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
event_type: Literal[
AgenticSystemTurnResponseEventType.turn_start.value
)
] = AgenticSystemTurnResponseEventType.turn_start.value
turn_id: str
@json_schema_type
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
event_type: Literal[
AgenticSystemTurnResponseEventType.turn_complete.value
)
] = AgenticSystemTurnResponseEventType.turn_complete.value
turn: Turn


@@ -63,36 +63,40 @@ class AgenticSystemStepResponse(BaseModel):
class AgenticSystem(Protocol):
@webmethod(route="/agentic_system/create")
async def create_agentic_system(
self,
request: AgenticSystemCreateRequest,
) -> AgenticSystemCreateResponse: ...
) -> AgenticSystemCreateResponse:
...
@webmethod(route="/agentic_system/turn/create")
async def create_agentic_system_turn(
self,
request: AgenticSystemTurnCreateRequest,
) -> AgenticSystemTurnResponseStreamChunk: ...
) -> AgenticSystemTurnResponseStreamChunk:
...
@webmethod(route="/agentic_system/turn/get")
async def get_agentic_system_turn(
self,
agent_id: str,
turn_id: str,
) -> Turn: ...
) -> Turn:
...
@webmethod(route="/agentic_system/step/get")
async def get_agentic_system_step(
self, agent_id: str, turn_id: str, step_id: str
) -> AgenticSystemStepResponse: ...
) -> AgenticSystemStepResponse:
...
@webmethod(route="/agentic_system/session/create")
async def create_agentic_system_session(
self,
request: AgenticSystemSessionCreateRequest,
) -> AgenticSystemSessionCreateResponse: ...
) -> AgenticSystemSessionCreateResponse:
...
@webmethod(route="/agentic_system/memory_bank/attach")
async def attach_memory_bank_to_agentic_system(
@@ -100,7 +104,8 @@ class AgenticSystem(Protocol):
agent_id: str,
session_id: str,
memory_bank_ids: List[str],
) -> None: ...
) -> None:
...
@webmethod(route="/agentic_system/memory_bank/detach")
async def detach_memory_bank_from_agentic_system(
@@ -108,7 +113,8 @@ class AgenticSystem(Protocol):
agent_id: str,
session_id: str,
memory_bank_ids: List[str],
) -> None: ...
) -> None:
...
@webmethod(route="/agentic_system/session/get")
async def get_agentic_system_session(
@@ -116,15 +122,18 @@ class AgenticSystem(Protocol):
agent_id: str,
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
) -> Session:
...
@webmethod(route="/agentic_system/session/delete")
async def delete_agentic_system_session(
self, agent_id: str, session_id: str
) -> None: ...
) -> None:
...
@webmethod(route="/agentic_system/delete")
async def delete_agentic_system(
self,
agent_id: str,
) -> None: ...
) -> None:
...


@@ -246,7 +246,6 @@ class AgentInstance(ShieldRunnerMixin):
await self.run_shields(messages, shields)
except SafetyException as e:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(


@@ -23,7 +23,6 @@ class SafetyException(Exception): # noqa: N818
class ShieldRunnerMixin:
def __init__(
self,
safety_api: Safety,


@@ -11,7 +11,6 @@ from llama_toolchain.inference.api import Message
class BaseTool(ABC):
@abstractmethod
def get_name(self) -> str:
raise NotImplementedError


@@ -66,7 +66,6 @@ class SingleMessageBuiltinTool(BaseTool):
class PhotogenTool(SingleMessageBuiltinTool):
def __init__(self, dump_dir: str) -> None:
self.dump_dir = dump_dir
@@ -87,7 +86,6 @@ class PhotogenTool(SingleMessageBuiltinTool):
class BraveSearchTool(SingleMessageBuiltinTool):
def __init__(self, api_key: str) -> None:
self.api_key = api_key
@@ -204,7 +202,6 @@ class BraveSearchTool(SingleMessageBuiltinTool):
class WolframAlphaTool(SingleMessageBuiltinTool):
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.url = "https://api.wolframalpha.com/v2/query"
@@ -287,7 +284,6 @@ class WolframAlphaTool(SingleMessageBuiltinTool):
class CodeInterpreterTool(BaseTool):
def __init__(self) -> None:
ctx = CodeExecutionContext(
matplotlib_dump_dir=f"/tmp/{os.environ['USER']}_matplotlib_dump",


@@ -27,7 +27,6 @@ from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
class AgenticSystemClientWrapper:
def __init__(self, api, system_id, custom_tools):
self.api = api
self.system_id = system_id


@@ -10,7 +10,6 @@ from llama_toolchain.cli.subcommand import Subcommand
class DistributionCreate(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(


@@ -16,7 +16,6 @@ from .start import DistributionStart
class DistributionParser(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(


@@ -11,7 +11,6 @@ from llama_toolchain.cli.subcommand import Subcommand
class DistributionList(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(


@@ -14,7 +14,6 @@ from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
class DistributionStart(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(


@@ -27,7 +27,11 @@ class ModelList(Subcommand):
self.parser.set_defaults(func=self._run_model_list_cmd)
def _add_arguments(self):
self.parser.add_argument('--show-all', action='store_true', help='Show all models (not just defaults)')
self.parser.add_argument(
"--show-all",
action="store_true",
help="Show all models (not just defaults)",
)
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
headers = [


@@ -26,16 +26,19 @@ class Datasets(Protocol):
def create_dataset(
self,
request: CreateDatasetRequest,
) -> None: ...
) -> None:
...
@webmethod(route="/datasets/get")
def get_dataset(
self,
dataset_uuid: str,
) -> TrainEvalDataset: ...
) -> TrainEvalDataset:
...
@webmethod(route="/datasets/delete")
def delete_dataset(
self,
dataset_uuid: str,
) -> None: ...
) -> None:
...


@@ -217,7 +217,6 @@ def create_dynamic_typed_route(func: Any):
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in providers}
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):


@@ -26,10 +26,8 @@ class SummarizationMetric(Enum):
class EvaluationJob(BaseModel):
job_uuid: str
class EvaluationJobLogStream(BaseModel):
job_uuid: str


@@ -48,7 +48,6 @@ class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
class EvaluationJobStatusResponse(BaseModel):
job_uuid: str
@@ -64,36 +63,42 @@ class Evaluations(Protocol):
def post_evaluate_text_generation(
self,
request: EvaluateTextGenerationRequest,
) -> EvaluationJob: ...
) -> EvaluationJob:
...
@webmethod(route="/evaluate/question_answering/")
def post_evaluate_question_answering(
self,
request: EvaluateQuestionAnsweringRequest,
) -> EvaluationJob: ...
) -> EvaluationJob:
...
@webmethod(route="/evaluate/summarization/")
def post_evaluate_summarization(
self,
request: EvaluateSummarizationRequest,
) -> EvaluationJob: ...
) -> EvaluationJob:
...
@webmethod(route="/evaluate/jobs")
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
def get_evaluation_jobs(self) -> List[EvaluationJob]:
...
@webmethod(route="/evaluate/job/status")
def get_evaluation_job_status(
self, job_uuid: str
) -> EvaluationJobStatusResponse: ...
def get_evaluation_job_status(self, job_uuid: str) -> EvaluationJobStatusResponse:
...
# sends SSE stream of logs
@webmethod(route="/evaluate/job/logs")
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream:
...
@webmethod(route="/evaluate/job/cancel")
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
def cancel_evaluation_job(self, job_uuid: str) -> None:
...
@webmethod(route="/evaluate/job/artifacts")
def get_evaluation_job_artifacts(
self, job_uuid: str
) -> EvaluationJobArtifactsResponse: ...
) -> EvaluationJobArtifactsResponse:
...


@@ -97,27 +97,30 @@ class BatchChatCompletionResponse(BaseModel):
class Inference(Protocol):
@webmethod(route="/inference/completion")
async def completion(
self,
request: CompletionRequest,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
...
@webmethod(route="/inference/chat_completion")
async def chat_completion(
self,
request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
...
@webmethod(route="/inference/batch_completion")
async def batch_completion(
self,
request: BatchCompletionRequest,
) -> BatchCompletionResponse: ...
) -> BatchCompletionResponse:
...
@webmethod(route="/inference/batch_chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> BatchChatCompletionResponse: ...
) -> BatchChatCompletionResponse:
...


@@ -7,8 +7,8 @@
from termcolor import cprint
from llama_toolchain.inference.api import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
)


@@ -45,7 +45,6 @@ SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference):
def __init__(self, config: MetaReferenceImplConfig) -> None:
self.config = config
model = resolve_model(config.model)


@@ -54,7 +54,6 @@ async def get_provider_impl(
class OllamaInference(Inference):
def __init__(self, config: OllamaImplConfig) -> None:
self.config = config
@@ -66,7 +65,9 @@ class OllamaInference(Inference):
try:
await self.client.ps()
except httpx.ConnectError:
raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
raise RuntimeError(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
)
async def shutdown(self) -> None:
pass


@@ -18,44 +18,52 @@ class MemoryBanks(Protocol):
bank_id: str,
bank_name: str,
documents: List[MemoryBankDocument],
) -> None: ...
) -> None:
...
@webmethod(route="/memory_banks/list")
def get_memory_banks(self) -> List[MemoryBank]: ...
def get_memory_banks(self) -> List[MemoryBank]:
...
@webmethod(route="/memory_banks/get")
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]:
...
@webmethod(route="/memory_banks/drop")
def delete_memory_bank(
self,
bank_id: str,
) -> str: ...
) -> str:
...
@webmethod(route="/memory_bank/insert")
def post_insert_memory_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
) -> None:
...
@webmethod(route="/memory_bank/update")
def post_update_memory_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
) -> None:
...
@webmethod(route="/memory_bank/get")
def get_memory_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[MemoryBankDocument]: ...
) -> List[MemoryBankDocument]:
...
@webmethod(route="/memory_bank/delete")
def delete_memory_documents(
self,
bank_id: str,
document_uuids: List[str],
) -> List[str]: ...
) -> List[str]:
...


@@ -11,4 +11,5 @@ from llama_models.schema_utils import webmethod # noqa: F401
from pydantic import BaseModel # noqa: F401
class Models(Protocol): ...
class Models(Protocol):
...


@@ -64,7 +64,6 @@ class PostTrainingRLHFRequest(BaseModel):
class PostTrainingJob(BaseModel):
job_uuid: str
@@ -99,30 +98,35 @@ class PostTraining(Protocol):
def post_supervised_fine_tune(
self,
request: PostTrainingSFTRequest,
) -> PostTrainingJob: ...
) -> PostTrainingJob:
...
@webmethod(route="/post_training/preference_optimize")
def post_preference_optimize(
self,
request: PostTrainingRLHFRequest,
) -> PostTrainingJob: ...
) -> PostTrainingJob:
...
@webmethod(route="/post_training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]: ...
def get_training_jobs(self) -> List[PostTrainingJob]:
...
# sends SSE stream of logs
@webmethod(route="/post_training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream:
...
@webmethod(route="/post_training/job/status")
def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
...
@webmethod(route="/post_training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
def cancel_training_job(self, job_uuid: str) -> None:
...
@webmethod(route="/post_training/job/artifacts")
def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...
) -> PostTrainingJobArtifactsResponse:
...


@@ -30,4 +30,5 @@ class RewardScoring(Protocol):
def post_score(
self,
request: RewardScoringRequest,
) -> Union[RewardScoringResponse]: ...
) -> Union[RewardScoringResponse]:
...


@@ -25,9 +25,9 @@ class RunShieldResponse(BaseModel):
class Safety(Protocol):
@webmethod(route="/safety/run_shields")
async def run_shields(
self,
request: RunShieldRequest,
) -> RunShieldResponse: ...
) -> RunShieldResponse:
...


@@ -41,7 +41,6 @@ def resolve_and_get_path(model_name: str) -> str:
class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None:
self.config = config


@@ -14,7 +14,6 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class ShieldBase(ABC):
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
@@ -60,7 +59,6 @@ class TextShield(ShieldBase):
class DummyShield(TextShield):
def get_shield_type(self) -> ShieldType:
return "dummy"


@@ -12,7 +12,6 @@ from llama_toolchain.safety.api.datatypes import * # noqa: F403
class CodeScannerShield(TextShield):
def get_shield_type(self) -> ShieldType:
return BuiltinShield.code_scanner_guard


@@ -100,7 +100,6 @@ PROMPT_TEMPLATE = Template(
class LlamaGuardShield(ShieldBase):
@staticmethod
def instance(
on_violation_action=OnViolationAction.RAISE,
@@ -166,7 +165,6 @@ class LlamaGuardShield(ShieldBase):
return None
def get_safety_categories(self) -> List[str]:
excluded_categories = self.excluded_categories
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
excluded_categories = []
@@ -181,7 +179,6 @@ class LlamaGuardShield(ShieldBase):
return categories
def build_prompt(self, messages: List[Message]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
@@ -225,7 +222,6 @@ class LlamaGuardShield(ShieldBase):
is_violation=False,
)
else:
prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",


@@ -18,7 +18,6 @@ from llama_toolchain.safety.api.datatypes import * # noqa: F403
class PromptGuardShield(TextShield):
class Mode(Enum):
INJECTION = auto()
JAILBREAK = auto()


@@ -37,4 +37,5 @@ class SyntheticDataGeneration(Protocol):
def post_generate(
self,
request: SyntheticDataGenerationRequest,
) -> Union[SyntheticDataGenerationResponse]: ...
) -> Union[SyntheticDataGenerationResponse]:
...


@@ -36,7 +36,6 @@ llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token <
class InferenceTests(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
# This runs the async setup function


@@ -20,7 +20,6 @@ from llama_toolchain.inference.ollama.ollama import get_provider_impl
class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
ollama_config = OllamaImplConfig(url="http://localhost:11434")