mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
formatting
This commit is contained in:
parent
94dfa293a6
commit
b6ccaf1778
33 changed files with 110 additions and 97 deletions
|
@ -63,9 +63,9 @@ class ShieldCallStep(StepCommon):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class MemoryRetrievalStep(StepCommon):
|
class MemoryRetrievalStep(StepCommon):
|
||||||
step_type: Literal[StepType.memory_retrieval.value] = (
|
step_type: Literal[
|
||||||
StepType.memory_retrieval.value
|
StepType.memory_retrieval.value
|
||||||
)
|
] = StepType.memory_retrieval.value
|
||||||
memory_bank_ids: List[str]
|
memory_bank_ids: List[str]
|
||||||
documents: List[MemoryBankDocument]
|
documents: List[MemoryBankDocument]
|
||||||
scores: List[float]
|
scores: List[float]
|
||||||
|
@ -140,9 +140,9 @@ class AgenticSystemTurnResponseEventType(Enum):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
|
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
|
event_type: Literal[
|
||||||
AgenticSystemTurnResponseEventType.step_start.value
|
AgenticSystemTurnResponseEventType.step_start.value
|
||||||
)
|
] = AgenticSystemTurnResponseEventType.step_start.value
|
||||||
step_type: StepType
|
step_type: StepType
|
||||||
step_id: str
|
step_id: str
|
||||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||||
|
@ -150,9 +150,9 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
|
event_type: Literal[
|
||||||
AgenticSystemTurnResponseEventType.step_complete.value
|
AgenticSystemTurnResponseEventType.step_complete.value
|
||||||
)
|
] = AgenticSystemTurnResponseEventType.step_complete.value
|
||||||
step_type: StepType
|
step_type: StepType
|
||||||
step_details: Step
|
step_details: Step
|
||||||
|
|
||||||
|
@ -161,9 +161,9 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
||||||
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
|
event_type: Literal[
|
||||||
AgenticSystemTurnResponseEventType.step_progress.value
|
AgenticSystemTurnResponseEventType.step_progress.value
|
||||||
)
|
] = AgenticSystemTurnResponseEventType.step_progress.value
|
||||||
step_type: StepType
|
step_type: StepType
|
||||||
step_id: str
|
step_id: str
|
||||||
|
|
||||||
|
@ -174,17 +174,17 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
|
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
|
event_type: Literal[
|
||||||
AgenticSystemTurnResponseEventType.turn_start.value
|
AgenticSystemTurnResponseEventType.turn_start.value
|
||||||
)
|
] = AgenticSystemTurnResponseEventType.turn_start.value
|
||||||
turn_id: str
|
turn_id: str
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
|
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
|
event_type: Literal[
|
||||||
AgenticSystemTurnResponseEventType.turn_complete.value
|
AgenticSystemTurnResponseEventType.turn_complete.value
|
||||||
)
|
] = AgenticSystemTurnResponseEventType.turn_complete.value
|
||||||
turn: Turn
|
turn: Turn
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -63,36 +63,40 @@ class AgenticSystemStepResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystem(Protocol):
|
class AgenticSystem(Protocol):
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/create")
|
@webmethod(route="/agentic_system/create")
|
||||||
async def create_agentic_system(
|
async def create_agentic_system(
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemCreateRequest,
|
request: AgenticSystemCreateRequest,
|
||||||
) -> AgenticSystemCreateResponse: ...
|
) -> AgenticSystemCreateResponse:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/turn/create")
|
@webmethod(route="/agentic_system/turn/create")
|
||||||
async def create_agentic_system_turn(
|
async def create_agentic_system_turn(
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemTurnCreateRequest,
|
request: AgenticSystemTurnCreateRequest,
|
||||||
) -> AgenticSystemTurnResponseStreamChunk: ...
|
) -> AgenticSystemTurnResponseStreamChunk:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/turn/get")
|
@webmethod(route="/agentic_system/turn/get")
|
||||||
async def get_agentic_system_turn(
|
async def get_agentic_system_turn(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
) -> Turn: ...
|
) -> Turn:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/step/get")
|
@webmethod(route="/agentic_system/step/get")
|
||||||
async def get_agentic_system_step(
|
async def get_agentic_system_step(
|
||||||
self, agent_id: str, turn_id: str, step_id: str
|
self, agent_id: str, turn_id: str, step_id: str
|
||||||
) -> AgenticSystemStepResponse: ...
|
) -> AgenticSystemStepResponse:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/session/create")
|
@webmethod(route="/agentic_system/session/create")
|
||||||
async def create_agentic_system_session(
|
async def create_agentic_system_session(
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemSessionCreateRequest,
|
request: AgenticSystemSessionCreateRequest,
|
||||||
) -> AgenticSystemSessionCreateResponse: ...
|
) -> AgenticSystemSessionCreateResponse:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/memory_bank/attach")
|
@webmethod(route="/agentic_system/memory_bank/attach")
|
||||||
async def attach_memory_bank_to_agentic_system(
|
async def attach_memory_bank_to_agentic_system(
|
||||||
|
@ -100,7 +104,8 @@ class AgenticSystem(Protocol):
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
memory_bank_ids: List[str],
|
memory_bank_ids: List[str],
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/memory_bank/detach")
|
@webmethod(route="/agentic_system/memory_bank/detach")
|
||||||
async def detach_memory_bank_from_agentic_system(
|
async def detach_memory_bank_from_agentic_system(
|
||||||
|
@ -108,7 +113,8 @@ class AgenticSystem(Protocol):
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
memory_bank_ids: List[str],
|
memory_bank_ids: List[str],
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/session/get")
|
@webmethod(route="/agentic_system/session/get")
|
||||||
async def get_agentic_system_session(
|
async def get_agentic_system_session(
|
||||||
|
@ -116,15 +122,18 @@ class AgenticSystem(Protocol):
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
turn_ids: Optional[List[str]] = None,
|
turn_ids: Optional[List[str]] = None,
|
||||||
) -> Session: ...
|
) -> Session:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/session/delete")
|
@webmethod(route="/agentic_system/session/delete")
|
||||||
async def delete_agentic_system_session(
|
async def delete_agentic_system_session(
|
||||||
self, agent_id: str, session_id: str
|
self, agent_id: str, session_id: str
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/delete")
|
@webmethod(route="/agentic_system/delete")
|
||||||
async def delete_agentic_system(
|
async def delete_agentic_system(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
...
|
||||||
|
|
|
@ -246,7 +246,6 @@ class AgentInstance(ShieldRunnerMixin):
|
||||||
await self.run_shields(messages, shields)
|
await self.run_shields(messages, shields)
|
||||||
|
|
||||||
except SafetyException as e:
|
except SafetyException as e:
|
||||||
|
|
||||||
yield AgenticSystemTurnResponseStreamChunk(
|
yield AgenticSystemTurnResponseStreamChunk(
|
||||||
event=AgenticSystemTurnResponseEvent(
|
event=AgenticSystemTurnResponseEvent(
|
||||||
payload=AgenticSystemTurnResponseStepCompletePayload(
|
payload=AgenticSystemTurnResponseStepCompletePayload(
|
||||||
|
|
|
@ -23,7 +23,6 @@ class SafetyException(Exception): # noqa: N818
|
||||||
|
|
||||||
|
|
||||||
class ShieldRunnerMixin:
|
class ShieldRunnerMixin:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
safety_api: Safety,
|
safety_api: Safety,
|
||||||
|
|
|
@ -11,7 +11,6 @@ from llama_toolchain.inference.api import Message
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(ABC):
|
class BaseTool(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -66,7 +66,6 @@ class SingleMessageBuiltinTool(BaseTool):
|
||||||
|
|
||||||
|
|
||||||
class PhotogenTool(SingleMessageBuiltinTool):
|
class PhotogenTool(SingleMessageBuiltinTool):
|
||||||
|
|
||||||
def __init__(self, dump_dir: str) -> None:
|
def __init__(self, dump_dir: str) -> None:
|
||||||
self.dump_dir = dump_dir
|
self.dump_dir = dump_dir
|
||||||
|
|
||||||
|
@ -87,7 +86,6 @@ class PhotogenTool(SingleMessageBuiltinTool):
|
||||||
|
|
||||||
|
|
||||||
class BraveSearchTool(SingleMessageBuiltinTool):
|
class BraveSearchTool(SingleMessageBuiltinTool):
|
||||||
|
|
||||||
def __init__(self, api_key: str) -> None:
|
def __init__(self, api_key: str) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
||||||
|
@ -204,7 +202,6 @@ class BraveSearchTool(SingleMessageBuiltinTool):
|
||||||
|
|
||||||
|
|
||||||
class WolframAlphaTool(SingleMessageBuiltinTool):
|
class WolframAlphaTool(SingleMessageBuiltinTool):
|
||||||
|
|
||||||
def __init__(self, api_key: str) -> None:
|
def __init__(self, api_key: str) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.url = "https://api.wolframalpha.com/v2/query"
|
self.url = "https://api.wolframalpha.com/v2/query"
|
||||||
|
@ -287,7 +284,6 @@ class WolframAlphaTool(SingleMessageBuiltinTool):
|
||||||
|
|
||||||
|
|
||||||
class CodeInterpreterTool(BaseTool):
|
class CodeInterpreterTool(BaseTool):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
ctx = CodeExecutionContext(
|
ctx = CodeExecutionContext(
|
||||||
matplotlib_dump_dir=f"/tmp/{os.environ['USER']}_matplotlib_dump",
|
matplotlib_dump_dir=f"/tmp/{os.environ['USER']}_matplotlib_dump",
|
||||||
|
|
|
@ -27,7 +27,6 @@ from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystemClientWrapper:
|
class AgenticSystemClientWrapper:
|
||||||
|
|
||||||
def __init__(self, api, system_id, custom_tools):
|
def __init__(self, api, system_id, custom_tools):
|
||||||
self.api = api
|
self.api = api
|
||||||
self.system_id = system_id
|
self.system_id = system_id
|
||||||
|
|
|
@ -10,7 +10,6 @@ from llama_toolchain.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
|
||||||
class DistributionCreate(Subcommand):
|
class DistributionCreate(Subcommand):
|
||||||
|
|
||||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.parser = subparsers.add_parser(
|
self.parser = subparsers.add_parser(
|
||||||
|
|
|
@ -16,7 +16,6 @@ from .start import DistributionStart
|
||||||
|
|
||||||
|
|
||||||
class DistributionParser(Subcommand):
|
class DistributionParser(Subcommand):
|
||||||
|
|
||||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.parser = subparsers.add_parser(
|
self.parser = subparsers.add_parser(
|
||||||
|
|
|
@ -11,7 +11,6 @@ from llama_toolchain.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
|
||||||
class DistributionList(Subcommand):
|
class DistributionList(Subcommand):
|
||||||
|
|
||||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.parser = subparsers.add_parser(
|
self.parser = subparsers.add_parser(
|
||||||
|
|
|
@ -14,7 +14,6 @@ from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
|
||||||
|
|
||||||
|
|
||||||
class DistributionStart(Subcommand):
|
class DistributionStart(Subcommand):
|
||||||
|
|
||||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.parser = subparsers.add_parser(
|
self.parser = subparsers.add_parser(
|
||||||
|
|
|
@ -27,7 +27,11 @@ class ModelList(Subcommand):
|
||||||
self.parser.set_defaults(func=self._run_model_list_cmd)
|
self.parser.set_defaults(func=self._run_model_list_cmd)
|
||||||
|
|
||||||
def _add_arguments(self):
|
def _add_arguments(self):
|
||||||
self.parser.add_argument('--show-all', action='store_true', help='Show all models (not just defaults)')
|
self.parser.add_argument(
|
||||||
|
"--show-all",
|
||||||
|
action="store_true",
|
||||||
|
help="Show all models (not just defaults)",
|
||||||
|
)
|
||||||
|
|
||||||
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
|
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
|
||||||
headers = [
|
headers = [
|
||||||
|
|
|
@ -26,16 +26,19 @@ class Datasets(Protocol):
|
||||||
def create_dataset(
|
def create_dataset(
|
||||||
self,
|
self,
|
||||||
request: CreateDatasetRequest,
|
request: CreateDatasetRequest,
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/datasets/get")
|
@webmethod(route="/datasets/get")
|
||||||
def get_dataset(
|
def get_dataset(
|
||||||
self,
|
self,
|
||||||
dataset_uuid: str,
|
dataset_uuid: str,
|
||||||
) -> TrainEvalDataset: ...
|
) -> TrainEvalDataset:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/datasets/delete")
|
@webmethod(route="/datasets/delete")
|
||||||
def delete_dataset(
|
def delete_dataset(
|
||||||
self,
|
self,
|
||||||
dataset_uuid: str,
|
dataset_uuid: str,
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
...
|
||||||
|
|
|
@ -217,7 +217,6 @@ def create_dynamic_typed_route(func: Any):
|
||||||
|
|
||||||
|
|
||||||
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||||
|
|
||||||
by_id = {x.api: x for x in providers}
|
by_id = {x.api: x for x in providers}
|
||||||
|
|
||||||
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||||
|
|
|
@ -26,10 +26,8 @@ class SummarizationMetric(Enum):
|
||||||
|
|
||||||
|
|
||||||
class EvaluationJob(BaseModel):
|
class EvaluationJob(BaseModel):
|
||||||
|
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
|
|
||||||
|
|
||||||
class EvaluationJobLogStream(BaseModel):
|
class EvaluationJobLogStream(BaseModel):
|
||||||
|
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
|
|
|
@ -48,7 +48,6 @@ class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
|
||||||
|
|
||||||
|
|
||||||
class EvaluationJobStatusResponse(BaseModel):
|
class EvaluationJobStatusResponse(BaseModel):
|
||||||
|
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,36 +63,42 @@ class Evaluations(Protocol):
|
||||||
def post_evaluate_text_generation(
|
def post_evaluate_text_generation(
|
||||||
self,
|
self,
|
||||||
request: EvaluateTextGenerationRequest,
|
request: EvaluateTextGenerationRequest,
|
||||||
) -> EvaluationJob: ...
|
) -> EvaluationJob:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/evaluate/question_answering/")
|
@webmethod(route="/evaluate/question_answering/")
|
||||||
def post_evaluate_question_answering(
|
def post_evaluate_question_answering(
|
||||||
self,
|
self,
|
||||||
request: EvaluateQuestionAnsweringRequest,
|
request: EvaluateQuestionAnsweringRequest,
|
||||||
) -> EvaluationJob: ...
|
) -> EvaluationJob:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/evaluate/summarization/")
|
@webmethod(route="/evaluate/summarization/")
|
||||||
def post_evaluate_summarization(
|
def post_evaluate_summarization(
|
||||||
self,
|
self,
|
||||||
request: EvaluateSummarizationRequest,
|
request: EvaluateSummarizationRequest,
|
||||||
) -> EvaluationJob: ...
|
) -> EvaluationJob:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/evaluate/jobs")
|
@webmethod(route="/evaluate/jobs")
|
||||||
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
|
def get_evaluation_jobs(self) -> List[EvaluationJob]:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/evaluate/job/status")
|
@webmethod(route="/evaluate/job/status")
|
||||||
def get_evaluation_job_status(
|
def get_evaluation_job_status(self, job_uuid: str) -> EvaluationJobStatusResponse:
|
||||||
self, job_uuid: str
|
...
|
||||||
) -> EvaluationJobStatusResponse: ...
|
|
||||||
|
|
||||||
# sends SSE stream of logs
|
# sends SSE stream of logs
|
||||||
@webmethod(route="/evaluate/job/logs")
|
@webmethod(route="/evaluate/job/logs")
|
||||||
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
|
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/evaluate/job/cancel")
|
@webmethod(route="/evaluate/job/cancel")
|
||||||
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
|
def cancel_evaluation_job(self, job_uuid: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/evaluate/job/artifacts")
|
@webmethod(route="/evaluate/job/artifacts")
|
||||||
def get_evaluation_job_artifacts(
|
def get_evaluation_job_artifacts(
|
||||||
self, job_uuid: str
|
self, job_uuid: str
|
||||||
) -> EvaluationJobArtifactsResponse: ...
|
) -> EvaluationJobArtifactsResponse:
|
||||||
|
...
|
||||||
|
|
|
@ -97,27 +97,30 @@ class BatchChatCompletionResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Inference(Protocol):
|
class Inference(Protocol):
|
||||||
|
|
||||||
@webmethod(route="/inference/completion")
|
@webmethod(route="/inference/completion")
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
|
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/inference/chat_completion")
|
@webmethod(route="/inference/chat_completion")
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
|
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/inference/batch_completion")
|
@webmethod(route="/inference/batch_completion")
|
||||||
async def batch_completion(
|
async def batch_completion(
|
||||||
self,
|
self,
|
||||||
request: BatchCompletionRequest,
|
request: BatchCompletionRequest,
|
||||||
) -> BatchCompletionResponse: ...
|
) -> BatchCompletionResponse:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/inference/batch_chat_completion")
|
@webmethod(route="/inference/batch_chat_completion")
|
||||||
async def batch_chat_completion(
|
async def batch_chat_completion(
|
||||||
self,
|
self,
|
||||||
request: BatchChatCompletionRequest,
|
request: BatchChatCompletionRequest,
|
||||||
) -> BatchChatCompletionResponse: ...
|
) -> BatchChatCompletionResponse:
|
||||||
|
...
|
||||||
|
|
|
@ -7,8 +7,8 @@
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.inference.api import (
|
from llama_toolchain.inference.api import (
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
ChatCompletionResponseStreamChunk
|
ChatCompletionResponseStreamChunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,6 @@ SEMAPHORE = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceInferenceImpl(Inference):
|
class MetaReferenceInferenceImpl(Inference):
|
||||||
|
|
||||||
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
model = resolve_model(config.model)
|
model = resolve_model(config.model)
|
||||||
|
|
|
@ -54,7 +54,6 @@ async def get_provider_impl(
|
||||||
|
|
||||||
|
|
||||||
class OllamaInference(Inference):
|
class OllamaInference(Inference):
|
||||||
|
|
||||||
def __init__(self, config: OllamaImplConfig) -> None:
|
def __init__(self, config: OllamaImplConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
@ -66,7 +65,9 @@ class OllamaInference(Inference):
|
||||||
try:
|
try:
|
||||||
await self.client.ps()
|
await self.client.ps()
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
|
raise RuntimeError(
|
||||||
|
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
|
||||||
|
)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -18,44 +18,52 @@ class MemoryBanks(Protocol):
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
bank_name: str,
|
bank_name: str,
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/list")
|
@webmethod(route="/memory_banks/list")
|
||||||
def get_memory_banks(self) -> List[MemoryBank]: ...
|
def get_memory_banks(self) -> List[MemoryBank]:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/get")
|
@webmethod(route="/memory_banks/get")
|
||||||
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
|
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/drop")
|
@webmethod(route="/memory_banks/drop")
|
||||||
def delete_memory_bank(
|
def delete_memory_bank(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
) -> str: ...
|
) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/insert")
|
@webmethod(route="/memory_bank/insert")
|
||||||
def post_insert_memory_documents(
|
def post_insert_memory_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/update")
|
@webmethod(route="/memory_bank/update")
|
||||||
def post_update_memory_documents(
|
def post_update_memory_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
) -> None: ...
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/get")
|
@webmethod(route="/memory_bank/get")
|
||||||
def get_memory_documents(
|
def get_memory_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
document_uuids: List[str],
|
document_uuids: List[str],
|
||||||
) -> List[MemoryBankDocument]: ...
|
) -> List[MemoryBankDocument]:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/delete")
|
@webmethod(route="/memory_bank/delete")
|
||||||
def delete_memory_documents(
|
def delete_memory_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
document_uuids: List[str],
|
document_uuids: List[str],
|
||||||
) -> List[str]: ...
|
) -> List[str]:
|
||||||
|
...
|
||||||
|
|
|
@ -11,4 +11,5 @@ from llama_models.schema_utils import webmethod # noqa: F401
|
||||||
from pydantic import BaseModel # noqa: F401
|
from pydantic import BaseModel # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class Models(Protocol): ...
|
class Models(Protocol):
|
||||||
|
...
|
||||||
|
|
|
@ -64,7 +64,6 @@ class PostTrainingRLHFRequest(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class PostTrainingJob(BaseModel):
|
class PostTrainingJob(BaseModel):
|
||||||
|
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,30 +98,35 @@ class PostTraining(Protocol):
|
||||||
def post_supervised_fine_tune(
|
def post_supervised_fine_tune(
|
||||||
self,
|
self,
|
||||||
request: PostTrainingSFTRequest,
|
request: PostTrainingSFTRequest,
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/post_training/preference_optimize")
|
@webmethod(route="/post_training/preference_optimize")
|
||||||
def post_preference_optimize(
|
def post_preference_optimize(
|
||||||
self,
|
self,
|
||||||
request: PostTrainingRLHFRequest,
|
request: PostTrainingRLHFRequest,
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/post_training/jobs")
|
@webmethod(route="/post_training/jobs")
|
||||||
def get_training_jobs(self) -> List[PostTrainingJob]: ...
|
def get_training_jobs(self) -> List[PostTrainingJob]:
|
||||||
|
...
|
||||||
|
|
||||||
# sends SSE stream of logs
|
# sends SSE stream of logs
|
||||||
@webmethod(route="/post_training/job/logs")
|
@webmethod(route="/post_training/job/logs")
|
||||||
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
|
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/post_training/job/status")
|
@webmethod(route="/post_training/job/status")
|
||||||
def get_training_job_status(
|
def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||||
self, job_uuid: str
|
...
|
||||||
) -> PostTrainingJobStatusResponse: ...
|
|
||||||
|
|
||||||
@webmethod(route="/post_training/job/cancel")
|
@webmethod(route="/post_training/job/cancel")
|
||||||
def cancel_training_job(self, job_uuid: str) -> None: ...
|
def cancel_training_job(self, job_uuid: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/post_training/job/artifacts")
|
@webmethod(route="/post_training/job/artifacts")
|
||||||
def get_training_job_artifacts(
|
def get_training_job_artifacts(
|
||||||
self, job_uuid: str
|
self, job_uuid: str
|
||||||
) -> PostTrainingJobArtifactsResponse: ...
|
) -> PostTrainingJobArtifactsResponse:
|
||||||
|
...
|
||||||
|
|
|
@ -30,4 +30,5 @@ class RewardScoring(Protocol):
|
||||||
def post_score(
|
def post_score(
|
||||||
self,
|
self,
|
||||||
request: RewardScoringRequest,
|
request: RewardScoringRequest,
|
||||||
) -> Union[RewardScoringResponse]: ...
|
) -> Union[RewardScoringResponse]:
|
||||||
|
...
|
||||||
|
|
|
@ -25,9 +25,9 @@ class RunShieldResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Safety(Protocol):
|
class Safety(Protocol):
|
||||||
|
|
||||||
@webmethod(route="/safety/run_shields")
|
@webmethod(route="/safety/run_shields")
|
||||||
async def run_shields(
|
async def run_shields(
|
||||||
self,
|
self,
|
||||||
request: RunShieldRequest,
|
request: RunShieldRequest,
|
||||||
) -> RunShieldResponse: ...
|
) -> RunShieldResponse:
|
||||||
|
...
|
||||||
|
|
|
@ -41,7 +41,6 @@ def resolve_and_get_path(model_name: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceSafetyImpl(Safety):
|
class MetaReferenceSafetyImpl(Safety):
|
||||||
|
|
||||||
def __init__(self, config: SafetyConfig) -> None:
|
def __init__(self, config: SafetyConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,6 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||||
|
|
||||||
|
|
||||||
class ShieldBase(ABC):
|
class ShieldBase(ABC):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||||
|
@ -60,7 +59,6 @@ class TextShield(ShieldBase):
|
||||||
|
|
||||||
|
|
||||||
class DummyShield(TextShield):
|
class DummyShield(TextShield):
|
||||||
|
|
||||||
def get_shield_type(self) -> ShieldType:
|
def get_shield_type(self) -> ShieldType:
|
||||||
return "dummy"
|
return "dummy"
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,6 @@ from llama_toolchain.safety.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
class CodeScannerShield(TextShield):
|
class CodeScannerShield(TextShield):
|
||||||
|
|
||||||
def get_shield_type(self) -> ShieldType:
|
def get_shield_type(self) -> ShieldType:
|
||||||
return BuiltinShield.code_scanner_guard
|
return BuiltinShield.code_scanner_guard
|
||||||
|
|
||||||
|
|
|
@ -100,7 +100,6 @@ PROMPT_TEMPLATE = Template(
|
||||||
|
|
||||||
|
|
||||||
class LlamaGuardShield(ShieldBase):
|
class LlamaGuardShield(ShieldBase):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def instance(
|
def instance(
|
||||||
on_violation_action=OnViolationAction.RAISE,
|
on_violation_action=OnViolationAction.RAISE,
|
||||||
|
@ -166,7 +165,6 @@ class LlamaGuardShield(ShieldBase):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_safety_categories(self) -> List[str]:
|
def get_safety_categories(self) -> List[str]:
|
||||||
|
|
||||||
excluded_categories = self.excluded_categories
|
excluded_categories = self.excluded_categories
|
||||||
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
|
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
|
||||||
excluded_categories = []
|
excluded_categories = []
|
||||||
|
@ -181,7 +179,6 @@ class LlamaGuardShield(ShieldBase):
|
||||||
return categories
|
return categories
|
||||||
|
|
||||||
def build_prompt(self, messages: List[Message]) -> str:
|
def build_prompt(self, messages: List[Message]) -> str:
|
||||||
|
|
||||||
categories = self.get_safety_categories()
|
categories = self.get_safety_categories()
|
||||||
categories_str = "\n".join(categories)
|
categories_str = "\n".join(categories)
|
||||||
conversations_str = "\n\n".join(
|
conversations_str = "\n\n".join(
|
||||||
|
@ -225,7 +222,6 @@ class LlamaGuardShield(ShieldBase):
|
||||||
is_violation=False,
|
is_violation=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
prompt = self.build_prompt(messages)
|
prompt = self.build_prompt(messages)
|
||||||
llama_guard_input = {
|
llama_guard_input = {
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|
|
@ -18,7 +18,6 @@ from llama_toolchain.safety.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
class PromptGuardShield(TextShield):
|
class PromptGuardShield(TextShield):
|
||||||
|
|
||||||
class Mode(Enum):
|
class Mode(Enum):
|
||||||
INJECTION = auto()
|
INJECTION = auto()
|
||||||
JAILBREAK = auto()
|
JAILBREAK = auto()
|
||||||
|
|
|
@ -37,4 +37,5 @@ class SyntheticDataGeneration(Protocol):
|
||||||
def post_generate(
|
def post_generate(
|
||||||
self,
|
self,
|
||||||
request: SyntheticDataGenerationRequest,
|
request: SyntheticDataGenerationRequest,
|
||||||
) -> Union[SyntheticDataGenerationResponse]: ...
|
) -> Union[SyntheticDataGenerationResponse]:
|
||||||
|
...
|
||||||
|
|
|
@ -36,7 +36,6 @@ llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token <
|
||||||
|
|
||||||
|
|
||||||
class InferenceTests(unittest.IsolatedAsyncioTestCase):
|
class InferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
# This runs the async setup function
|
# This runs the async setup function
|
||||||
|
|
|
@ -20,7 +20,6 @@ from llama_toolchain.inference.ollama.ollama import get_provider_impl
|
||||||
|
|
||||||
|
|
||||||
class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
|
class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
ollama_config = OllamaImplConfig(url="http://localhost:11434")
|
ollama_config = OllamaImplConfig(url="http://localhost:11434")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue