Fix precommit check after moving to ruff (#927)

Lint check in main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fixing and ignoring some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
217 changed files with 981 additions and 2681 deletions

View file

@ -11,9 +11,7 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(
config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]
):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(

View file

@ -74,9 +74,7 @@ log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
@ -153,9 +151,7 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
@ -206,14 +202,9 @@ class ChatAgent(ShieldRunnerMixin):
output_message = chunk
continue
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
@ -388,9 +379,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
await self.handle_documents(session_id, documents, input_messages, tool_defs)
if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
@ -408,9 +397,7 @@ class ChatAgent(ShieldRunnerMixin):
vector_db_ids = args.get("vector_db_ids", [])
query_config = args.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(
query_config
)
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
@ -438,9 +425,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
result = await self.tool_runtime_api.rag_tool.query(
content=concat_interleaved_content(
[msg.content for msg in input_messages]
),
content=concat_interleaved_content([msg.content for msg in input_messages]),
vector_db_ids=vector_db_ids,
query_config=query_config,
)
@ -472,9 +457,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", retrieved_context)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
@ -511,9 +494,7 @@ class ChatAgent(ShieldRunnerMixin):
self.agent_config.model,
input_messages,
tools=[
tool
for tool in tool_defs.values()
if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
tool_prompt_format=self.agent_config.tool_prompt_format,
response_format=self.agent_config.response_format,
@ -560,12 +541,8 @@ class ChatAgent(ShieldRunnerMixin):
if event.stop_reason is not None:
stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute(
"output", f"content: {content} tool_calls: {tool_calls}"
)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}")
stop_reason = stop_reason or StopReason.out_of_tokens
@ -667,9 +644,7 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_args,
tool_to_group,
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
@ -697,9 +672,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if out_attachment := _interpret_content_as_attachment(
result_message.content
):
if out_attachment := _interpret_content_as_attachment(result_message.content):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
@ -714,22 +687,14 @@ class ChatAgent(ShieldRunnerMixin):
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
# Determine which tools to include
agent_config_toolgroups = set(
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
for toolgroup in self.agent_config.toolgroups
)
toolgroups_for_turn_set = (
agent_config_toolgroups
if toolgroups_for_turn is None
else {
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
for toolgroup in toolgroups_for_turn
}
)
@ -759,10 +724,7 @@ class ChatAgent(ShieldRunnerMixin):
continue
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
for tool_def in tools.data:
if (
toolgroup_name.startswith("builtin")
and toolgroup_name != RAG_TOOL_GROUP
):
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
tool_name = tool_def.identifier
built_in_type = BuiltinTool.brave_search
if tool_name == "web_search":
@ -773,9 +735,7 @@ class ChatAgent(ShieldRunnerMixin):
if tool_def_map.get(built_in_type, None):
raise ValueError(f"Tool {built_in_type} already exists")
tool_def_map[built_in_type] = ToolDefinition(
tool_name=built_in_type
)
tool_def_map[built_in_type] = ToolDefinition(tool_name=built_in_type)
tool_to_group[built_in_type] = tool_def.toolgroup_id
continue
@ -821,9 +781,7 @@ class ChatAgent(ShieldRunnerMixin):
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
if code_interpreter_tool:
for c in content_items:
temp_file_path = os.path.join(
self.tempdir, f"{make_random_string()}.txt"
)
temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
@ -849,8 +807,7 @@ class ChatAgent(ShieldRunnerMixin):
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = "\n".join(
[doc.content for doc in content_items]
+ await load_data_from_urls(url_items)
[doc.content for doc in content_items] + await load_data_from_urls(url_items)
)
async def _ensure_vector_db(self, session_id: str) -> str:
@ -874,9 +831,7 @@ class ChatAgent(ShieldRunnerMixin):
return vector_db_id
async def add_to_session_vector_db(
self, session_id: str, data: List[Document]
) -> None:
async def add_to_session_vector_db(self, session_id: str, data: List[Document]) -> None:
vector_db_id = await self._ensure_vector_db(session_id)
documents = [
RAGDocument(
@ -931,11 +886,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
else:
raise ValueError(f"Unsupported URL {url}")
content.append(
TextContentItem(
text=f'# There is a file accessible to you at "{filepath}"\n'
)
)
content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
return ToolResponseMessage(
call_id="",

View file

@ -94,16 +94,12 @@ class MetaReferenceAgentsImpl(Agents):
try:
agent_config = json.loads(agent_config)
except json.JSONDecodeError as e:
raise ValueError(
f"Could not JSON decode agent config for {agent_id}"
) from e
raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
try:
agent_config = AgentConfig(**agent_config)
except Exception as e:
raise ValueError(
f"Could not validate(?) agent config for {agent_id}"
) from e
raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
return ChatAgent(
agent_id=agent_id,
@ -115,9 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store
if agent_config.enable_session_persistence
else self.in_memory_store
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
),
)
@ -168,22 +162,14 @@ class MetaReferenceAgentsImpl(Agents):
async for event in agent.create_and_execute_turn(request):
yield event
async def get_agents_turn(
self, agent_id: str, session_id: str, turn_id: str
) -> Turn:
turn = await self.persistence_store.get(
f"session:{agent_id}:{session_id}:{turn_id}"
)
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
return turn
async def get_agents_step(
self, agent_id: str, session_id: str, turn_id: str, step_id: str
) -> AgentStepResponse:
turn = await self.persistence_store.get(
f"session:{agent_id}:{session_id}:{turn_id}"
)
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
steps = turn.steps
@ -203,9 +189,7 @@ class MetaReferenceAgentsImpl(Agents):
turns = []
if turn_ids:
for turn_id in turn_ids:
turn = await self.persistence_store.get(
f"session:{agent_id}:{session_id}:{turn_id}"
)
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
turns.append(turn)

View file

@ -33,9 +33,7 @@ class ShieldRunnerMixin:
self.input_shields = input_shields
self.output_shields = output_shields
async def run_multiple_shields(
self, messages: List[Message], identifiers: List[str]
) -> None:
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
responses = await asyncio.gather(
*[
self.safety_api.run_shield(

View file

@ -64,9 +64,7 @@ class MockInferenceAPI:
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
async def stream_response():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -104,9 +102,7 @@ class MockInferenceAPI:
class MockSafetyAPI:
async def run_shield(
self, shield_id: str, messages: List[Message]
) -> RunShieldResponse:
async def run_shield(self, shield_id: str, messages: List[Message]) -> RunShieldResponse:
return RunShieldResponse(violation=None)
@ -129,9 +125,7 @@ class MockVectorIOAPI:
class MockToolGroupsAPI:
async def register_tool_group(
self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
) -> None:
async def register_tool_group(self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None) -> None:
pass
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
@ -341,26 +335,21 @@ async def test_chat_agent_complex_turn(get_chat_agent):
assert len(responses) > 0
step_types = [
response.event.payload.step_type
for response in responses
if hasattr(response.event.payload, "step_type")
response.event.payload.step_type for response in responses if hasattr(response.event.payload, "step_type")
]
assert StepType.shield_call in step_types, "Shield call step is missing"
assert StepType.inference in step_types, "Inference step is missing"
event_types = [
response.event.payload.event_type
for response in responses
if hasattr(response.event.payload, "event_type")
response.event.payload.event_type for response in responses if hasattr(response.event.payload, "event_type")
]
assert "turn_start" in event_types, "Start event is missing"
assert "turn_complete" in event_types, "Complete event is missing"
assert any(
isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
for response in responses
), "Turn complete event is missing"
assert any(isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) for response in responses), (
"Turn complete event is missing"
)
turn_complete_payload = next(
response.event.payload
for response in responses
@ -380,9 +369,7 @@ async def test_chat_agent_complex_turn(get_chat_agent):
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
],
)
async def test_chat_agent_tools(
get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
):
async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, expected_code_interpreter):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",

View file

@ -172,9 +172,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat(
[dataset_impl.df, new_rows_df], ignore_index=True
)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url)
parsed_url = urlparse(url)
@ -189,12 +187,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
raise ValueError("Data URL must be a base64-encoded CSV")
csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
"utf-8"
)
dataset_info.dataset_def.url = URL(
uri=f"data:text/csv;base64,{base64_content}"
)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
else:
raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."

View file

@ -91,14 +91,10 @@ class MetaReferenceEvalImpl(
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(
-1 if task_config.num_examples is None else task_config.num_examples
),
rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples),
)
res = await self.evaluate_rows(
task_id=task_id,
@ -127,9 +123,7 @@ class MetaReferenceEvalImpl(
input_messages = [UserMessage(**x) for x in input_messages]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
session_create_response = await self.agents_api.create_agent_session(
agent_id, f"session-{i}"
)
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
session_id = session_create_response.session_id
turn_request = dict(
@ -138,12 +132,7 @@ class MetaReferenceEvalImpl(
messages=input_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await self.agents_api.create_agent_turn(
**turn_request
)
]
turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
final_event = turn_response[-1].event.payload
# check if there's a memory retrieval step and extract the context
@ -152,14 +141,10 @@ class MetaReferenceEvalImpl(
if step.step_type == StepType.tool_execution.value:
for tool_response in step.tool_responses:
if tool_response.tool_name == MEMORY_QUERY_TOOL:
memory_rag_context = " ".join(
x.text for x in tool_response.content
)
memory_rag_context = " ".join(x.text for x in tool_response.content)
agent_generation = {}
agent_generation[ColumnName.generated_answer.value] = (
final_event.turn.output_message.content
)
agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
if memory_rag_context:
agent_generation[ColumnName.context.value] = memory_rag_context
@ -171,9 +156,7 @@ class MetaReferenceEvalImpl(
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
) -> List[Dict[str, Any]]:
candidate = task_config.eval_candidate
assert (
candidate.sampling_params.max_tokens is not None
), "SamplingParams.max_tokens must be provided"
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
generations = []
for x in tqdm(input_rows):
@ -184,15 +167,9 @@ class MetaReferenceEvalImpl(
content=input_content,
sampling_params=candidate.sampling_params,
)
generations.append(
{
ColumnName.generated_answer.value: response.completion_message.content
}
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_str = str(
x[ColumnName.chat_completion_input.value]
)
chat_completion_input_str = str(x[ColumnName.chat_completion_input.value])
input_messages = eval(chat_completion_input_str)
input_messages = [UserMessage(**x) for x in input_messages]
messages = []
@ -204,11 +181,7 @@ class MetaReferenceEvalImpl(
messages=messages,
sampling_params=candidate.sampling_params,
)
generations.append(
{
ColumnName.generated_answer.value: response.completion_message.content
}
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
else:
raise ValueError("Invalid input row")
@ -230,10 +203,7 @@ class MetaReferenceEvalImpl(
raise ValueError(f"Invalid candidate type: {candidate.type}")
# scoring with generated_answer
score_input_rows = [
input_r | generated_r
for input_r, generated_r in zip(input_rows, generations)
]
score_input_rows = [input_r | generated_r for input_r, generated_r in zip(input_rows, generations)]
if task_config.type == "app" and task_config.scoring_params is not None:
scoring_functions_dict = {
@ -241,9 +211,7 @@ class MetaReferenceEvalImpl(
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {
scoring_fn_id: None for scoring_fn_id in scoring_functions
}
scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions_dict

View file

@ -40,9 +40,7 @@ class MetaReferenceInferenceConfig(BaseModel):
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
return model
@classmethod

View file

@ -83,9 +83,7 @@ class TokenResult(BaseModel):
class Llama:
@staticmethod
def build(
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
],
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
model_id: str,
llama_model: Model,
):
@ -150,9 +148,9 @@ class Llama:
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
assert model_parallel_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
@ -168,9 +166,9 @@ class Llama:
)
tokenizer = Tokenizer.get_instance()
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
assert model_args.vocab_size == tokenizer.n_words, (
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
)
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
@ -193,10 +191,7 @@ class Llama:
model = convert_to_int4_quantized_model(model, model_args, config)
model.load_state_dict(state_dict, strict=True)
if (
model_args.quantization_args is not None
and model_args.quantization_args.spinquant
):
if model_args.quantization_args is not None and model_args.quantization_args.spinquant:
# Add a wrapper for adding hadamard transform for spinquant.
# This needs to be done after loading the state dict otherwise an error will be raised while
# loading the state dict.
@ -206,9 +201,7 @@ class Llama:
add_hadamard_transform_for_spinquant(model)
else:
raise NotImplementedError(
"Currently int4 and fp8 are the only supported quantization methods."
)
raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.")
else:
if device == "cuda":
if torch.cuda.is_bf16_supported():
@ -262,10 +255,7 @@ class Llama:
params = self.model.params
if print_input_tokens:
input_tokens = [
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
]
input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
prompt_tokens = [model_input.tokens]
@ -287,12 +277,10 @@ class Llama:
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
pad_id = self.tokenizer.pad_id
@ -340,9 +328,7 @@ class Llama:
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
@ -365,17 +351,11 @@ class Llama:
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
if logprobs
else None
),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
prev_pos = cur_pos
@ -388,11 +368,7 @@ class Llama:
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(request.content)
@ -417,11 +393,7 @@ class Llama:
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
@ -473,9 +445,7 @@ class LogitsProcessor:
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
def process_logits(
self, tokens: torch.Tensor, scores: torch.Tensor
) -> torch.Tensor:
def process_logits(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
@ -510,9 +480,7 @@ def get_logits_processor(
return LogitsProcessor(token_enforcer)
def _build_regular_tokens_list(
tokenizer: Tokenizer, vocab_size: int
) -> List[Tuple[int, str, bool]]:
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []

View file

@ -80,9 +80,7 @@ class MetaReferenceInferenceImpl(
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
self.config, model_id, llama_model
)
self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
self.generator.start()
else:
self.generator = Llama.build(self.config, model_id, llama_model)
@ -100,9 +98,7 @@ class MetaReferenceInferenceImpl(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model_id:
raise RuntimeError(
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
)
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def unregister_model(self, model_id: str) -> None:
pass
@ -184,13 +180,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
]
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
yield CompletionResponseStreamChunk(
delta=text,
@ -212,9 +202,7 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
def impl():
tokens = []
logprobs = []
@ -231,13 +219,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
@ -289,9 +271,7 @@ class MetaReferenceInferenceImpl(
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.llama_model.core_model_id.value
)
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
@ -304,9 +284,7 @@ class MetaReferenceInferenceImpl(
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
def impl():
tokens = []
logprobs = []
@ -323,20 +301,12 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
@ -352,9 +322,7 @@ class MetaReferenceInferenceImpl(
else:
return impl()
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -405,13 +373,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@ -424,9 +386,7 @@ class MetaReferenceInferenceImpl(
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:

View file

@ -91,9 +91,7 @@ class LlamaModelParallelGenerator:
self.group = ModelParallelProcessGroup(
model_parallel_size,
init_model_cb=partial(
init_model_cb, self.config, self.model_id, self.llama_model
),
init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
)
self.group.start()
return self

View file

@ -55,47 +55,33 @@ class ProcessingMessageName(str, Enum):
class ReadyRequest(BaseModel):
type: Literal[ProcessingMessageName.ready_request] = (
ProcessingMessageName.ready_request
)
type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
class ReadyResponse(BaseModel):
type: Literal[ProcessingMessageName.ready_response] = (
ProcessingMessageName.ready_response
)
type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
class EndSentinel(BaseModel):
type: Literal[ProcessingMessageName.end_sentinel] = (
ProcessingMessageName.end_sentinel
)
type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
class CancelSentinel(BaseModel):
type: Literal[ProcessingMessageName.cancel_sentinel] = (
ProcessingMessageName.cancel_sentinel
)
type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = (
ProcessingMessageName.task_response
)
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
result: TokenResult
class ExceptionResponse(BaseModel):
type: Literal[ProcessingMessageName.exception_response] = (
ProcessingMessageName.exception_response
)
type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
error: str
@ -189,9 +175,7 @@ def retrieve_requests(reply_socket_url: str):
group=get_model_parallel_group(),
)
if isinstance(updates[0], CancelSentinel):
log.info(
"quitting generation loop because request was cancelled"
)
log.info("quitting generation loop because request was cancelled")
break
if mp_rank_0():
@ -350,9 +334,7 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: Union[
CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
],
req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
) -> Generator:
assert not self.running, "inference already running"

View file

@ -19,9 +19,7 @@ try:
log.info("Using efficient FP8 operators in FBGEMM.")
except ImportError:
log.error(
"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
)
log.error("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
raise
import torch
@ -60,14 +58,8 @@ def ffn_swiglu(
num_tokens: Optional[Tensor] = None,
is_memory_bounded: bool = False,
) -> Tensor:
if (
isinstance(w1, Fp8ScaledWeights)
and isinstance(w3, Fp8ScaledWeights)
and isinstance(w2, Fp8ScaledWeights)
):
return ffn_swiglu_fp8_dynamic(
x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded
)
if isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights):
return ffn_swiglu_fp8_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
(B, T, D) = x.shape # noqa: N806
(HD_L, D_) = w1.shape # noqa: N806
@ -146,12 +138,8 @@ def fc_fp8_dynamic(
Single w8a8 fc layer with dynamic row-wise scaling.
"""
if isinstance(w, Fp8RowwiseWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x, num_tokens, activation_scale_ub
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
xq, w.weight, x_scale, w.scale, use_fast_accum=True
)
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
del xq
return y

View file

@ -17,8 +17,7 @@ from torch import Tensor
@unittest.skipIf(
not torch.cuda.is_available()
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
not torch.cuda.is_available() or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
"Skip when H100 is not available",
)
class FP8Tests(unittest.TestCase):

View file

@ -57,9 +57,7 @@ class HadamardModule(torch.nn.Module):
return x
def add_hadamard_transform_for_spinquant(
model: torch.nn.Module, prefix: str = ""
) -> None:
def add_hadamard_transform_for_spinquant(model: torch.nn.Module, prefix: str = "") -> None:
"""
Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
This function recursively traverses the model's children and looks for layers that match the pattern
@ -81,12 +79,8 @@ def add_hadamard_transform_for_spinquant(
for module_name, module in model.named_children():
child_full_name = prefix + "." + module_name
if re.search(pattern_last_linear_ffn, child_full_name):
new_module = nn.Sequential(
HadamardModule(group_size=module.in_features), module
)
new_module = nn.Sequential(HadamardModule(group_size=module.in_features), module)
del module
setattr(model, module_name, new_module)
else:
add_hadamard_transform_for_spinquant(
module, (prefix + "." if prefix else prefix) + module_name
)
add_hadamard_transform_for_spinquant(module, (prefix + "." if prefix else prefix) + module_name)

View file

@ -63,12 +63,8 @@ def convert_to_fp8_quantized_model(
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
@ -81,9 +77,7 @@ def convert_to_fp8_quantized_model(
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
],
fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
fp8_activation_scale_ub,
)
else:
@ -172,9 +166,7 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
if prefix + "zeros" not in state_dict:
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
assert prefix + "scales" in state_dict
state_dict[prefix + "zeros"] = torch.zeros_like(
state_dict[prefix + "scales"]
)
state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
def forward(self, input_: torch.Tensor) -> torch.Tensor:
module_out = super().forward(input_)
@ -229,9 +221,7 @@ class Int8WeightLinear(torch.nn.Linear):
bias: Whether to use bias.
"""
def __init__(
self, in_features: int, out_features: int, bias: bool = True, device=None
) -> None:
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
super().__init__(in_features, out_features, bias, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
@ -295,9 +285,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
del module
setattr(model, module_name, quantized_module)
else:
_prepare_model_int4_weight_int8_dynamic_activation(
module, group_size, lora_rank, lora_scale
)
_prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
return model
@ -321,9 +309,7 @@ def convert_to_int4_quantized_model(
group_size = model_args.quantization_args.group_size
if group_size is None:
raise ValueError(
"'group_size' cannot be None in 'quantization_args'. Please specify it."
)
raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
if model_args.lora_args is None:
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
@ -333,8 +319,6 @@ def convert_to_int4_quantized_model(
lora_rank = model_args.lora_args.rank
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(
model, group_size, lora_rank, lora_scale
)
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return model.to(device)

View file

@ -76,9 +76,9 @@ def main(
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
assert model_parallel_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
@ -90,9 +90,9 @@ def main(
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
assert model_args.vocab_size == tokenizer.n_words, (
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
)
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
@ -106,9 +106,7 @@ def main(
torch.set_default_tensor_type(torch.cuda.HalfTensor)
log.info(ckpt_path)
assert (
quantized_ckpt_dir is not None
), "QUantized checkpoint directory should not be None"
assert quantized_ckpt_dir is not None, "QUantized checkpoint directory should not be None"
fp8_scales = {}
for block in model.layers:
if isinstance(block, TransformerBlock):
@ -122,9 +120,7 @@ def main(
)
with torch.inference_mode():
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales[f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w3.weight,
@ -133,9 +129,7 @@ def main(
)
with torch.inference_mode():
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales[f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_weight = quantize_fp8(
block.feed_forward.w2.weight,
@ -144,13 +138,9 @@ def main(
)
with torch.inference_mode():
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
fp8_scales[
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
] = fp8_weight.scale
fp8_scales[f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"] = fp8_weight.scale
fp8_scales_path = os.path.join(
quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
fp8_scales_path = os.path.join(quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
torch.save(fp8_scales, fp8_scales_path)
ckpt_path = os.path.join(

View file

@ -10,7 +10,6 @@ from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel):
@classmethod
def sample_run_config(cls) -> Dict[str, Any]:
return {}

View file

@ -53,7 +53,5 @@ class VLLMConfig(BaseModel):
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
return model

View file

@ -176,13 +176,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = await chat_completion_request_to_prompt(
request, self.config.model, self.formatter
)
prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
)
results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
if stream:
return self._stream_chat_completion(request, results_generator)
else:
@ -230,12 +226,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def embeddings(
self, model_id: str, contents: List[InterleavedContent]
) -> EmbeddingsResponse:
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -47,6 +47,4 @@ async def validate_input_dataset_schema(
if dataset_type not in EXPECTED_DATASET_SCHEMA:
raise ValueError(f"Dataset type {dataset_type} is not supported.")
validate_dataset_schema(
dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
)
validate_dataset_schema(dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type])

View file

@ -42,9 +42,7 @@ class TorchtuneCheckpointer:
self._model_type = ModelType[model_type]
self._output_dir = output_dir
# get ckpt paths
self._checkpoint_path = Path.joinpath(
self._checkpoint_dir, self._checkpoint_file
)
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
def load_checkpoint(self) -> Dict[str, Any]:
"""
@ -57,13 +55,9 @@ class TorchtuneCheckpointer:
llama3_vision_meta_to_tune,
)
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
model_state_dict
)
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(model_state_dict)
else:
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
model_state_dict
)
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict)
# llama3_2 has tied weights, so we need to remove the output.weight key
if self._model_type == ModelType.LLAMA3_2:
@ -82,10 +76,7 @@ class TorchtuneCheckpointer:
epoch: int,
adapter_only: bool = False,
) -> str:
model_file_path = (
Path(self._output_dir)
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
)
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
model_file_path.mkdir(parents=True, exist_ok=True)
@ -116,22 +107,13 @@ class TorchtuneCheckpointer:
llama3_vision_tune_to_meta,
)
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
model_state_dict
)
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(model_state_dict)
else:
# llama3_2 has tied weights, so we need to add the output.weight key
if (
self._model_type == ModelType.LLAMA3_2
and "output.weight" not in model_state_dict
):
model_state_dict["output.weight"] = model_state_dict[
"tok_embeddings.weight"
]
if self._model_type == ModelType.LLAMA3_2 and "output.weight" not in model_state_dict:
model_state_dict["output.weight"] = model_state_dict["tok_embeddings.weight"]
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
model_state_dict
)
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict)
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")

View file

@ -15,18 +15,13 @@ from typing import Any, Mapping
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
def llama_stack_instruct_to_torchtune_instruct(
sample: Mapping[str, Any]
) -> Mapping[str, Any]:
assert (
ColumnName.chat_completion_input.value in sample
and ColumnName.expected_answer.value in sample
), "Invalid input row"
def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]:
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
"Invalid input row"
)
input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
assert (
len(input_messages) == 1
), "llama stack intruct dataset format only supports 1 user message"
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0]
assert "content" in input_message, "content not found in input message"
@ -48,13 +43,9 @@ def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str
roles = []
conversations = []
for message in dialog:
assert (
"role" in message and "content" in message
), "role and content must in message"
assert "role" in message and "content" in message, "role and content must in message"
roles.append(message["role"])
conversations.append(
{"from": role_map[message["role"]], "value": message["content"]}
)
conversations.append({"from": role_map[message["role"]], "value": message["content"]})
assert roles[0] == "user", "first message must be from user"
assert "assistant" in roles, "at least 1 message should be from assistant"

View file

@ -61,8 +61,7 @@ class SFTDataset(Dataset):
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())
error_message = (
"model_transform returned the following keys: "
f"{keys_str}. Must return 'tokens' and 'mask' as keys."
f"model_transform returned the following keys: {keys_str}. Must return 'tokens' and 'mask' as keys."
)
raise ValueError(error_message)

View file

@ -119,9 +119,7 @@ class TorchtunePostTrainingImpl:
return ListPostTrainingJobsResponse(data=self.jobs_list)
@webmethod(route="/post-training/job/status")
async def get_training_job_status(
self, job_uuid: str
) -> Optional[PostTrainingJobStatusResponse]:
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
if job_uuid in self.jobs_status:
return self.jobs_status[job_uuid]
return None
@ -131,12 +129,8 @@ class TorchtunePostTrainingImpl:
raise NotImplementedError("Job cancel is not implemented yet")
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(
self, job_uuid: str
) -> Optional[PostTrainingJobArtifactsResponse]:
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(
job_uuid=job_uuid, checkpoints=checkpoints
)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
return None

View file

@ -94,9 +94,7 @@ class LoraFinetuningSingleDevice:
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError(
"You need to speicifc LoraFinetuningConfig for LoRA finetuning"
)
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device(device="cuda")
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@ -105,10 +103,7 @@ class LoraFinetuningSingleDevice:
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
paths = [
Path(checkpoint_dir / f"consolidated.{ext}")
for ext in ["pth", "00.pth"]
]
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
@ -123,9 +118,7 @@ class LoraFinetuningSingleDevice:
else:
model = resolve_model(self.model_id)
if model is None:
raise ValueError(
f"{self.model_id} not found. Your model id should be in the llama models SKU list"
)
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
self.checkpoint_dir = model_checkpoint_dir(model)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@ -196,9 +189,7 @@ class LoraFinetuningSingleDevice:
self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized.")
self._optimizer = await self._setup_optimizer(
optimizer_config=self.training_config.optimizer_config
)
self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
log.info("Optimizer is initialized.")
self._loss_fn = CEWithChunkedOutputLoss()
@ -226,13 +217,8 @@ class LoraFinetuningSingleDevice:
# by the dataloader and the max_steps_per_epoch param set by the user and is used
# for logging and tracking training state. This should be computed after the dataloader
# has been setup
self._steps_per_epoch = (
len(self._training_dataloader) // self._gradient_accumulation_steps
)
if (
self.max_steps_per_epoch is not None
and self.max_steps_per_epoch < self._steps_per_epoch
):
self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch
@ -246,9 +232,7 @@ class LoraFinetuningSingleDevice:
log.info("Learning rate scheduler is initialized.")
# Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full(
(self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
)
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
async def _setup_model(
self,
@ -282,13 +266,9 @@ class LoraFinetuningSingleDevice:
set_trainable_params(model, self.adapter_params)
if enable_activation_checkpointing:
training.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)
training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})
base_missing, base_unexpected = model.load_state_dict(
base_model_state_dict, strict=False
)
base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)
# This is for any adapters that need to be initialized after base weights
# have been loaded (e.g. DoRA).
@ -297,9 +277,7 @@ class LoraFinetuningSingleDevice:
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False
)
lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
else:
lora_missing, lora_unexpected = None, None
validate_missing_and_unexpected_for_lora(
@ -313,14 +291,10 @@ class LoraFinetuningSingleDevice:
)
# Validate model adapter params were loaded in with the expected dtype
training.validate_expected_param_dtype(
self.adapter_params.items(), dtype=self._dtype
)
training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)
# activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
@ -456,9 +430,7 @@ class LoraFinetuningSingleDevice:
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we dont need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
(labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
if not isinstance(logits, list):
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))
@ -487,9 +459,7 @@ class LoraFinetuningSingleDevice:
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
)
metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}")
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0
@ -497,8 +467,7 @@ class LoraFinetuningSingleDevice:
for idx, batch in enumerate(self._training_dataloader):
if (
self.max_steps_per_epoch is not None
and (idx // self._gradient_accumulation_steps)
== self.max_steps_per_epoch
and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
):
break
@ -506,9 +475,7 @@ class LoraFinetuningSingleDevice:
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
num_tokens += current_num_tokens
# Loss is normalized by default so we multiply by the number of tokens
@ -533,9 +500,7 @@ class LoraFinetuningSingleDevice:
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
)
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
time_per_step = time.perf_counter() - t0
log_dict = {

View file

@ -67,10 +67,6 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
violation = SafetyViolation(
violation_level=(ViolationLevel.ERROR),
user_message="Sorry, I found security concerns in the code.",
metadata={
"violation_type": ",".join(
[issue.pattern_id for issue in result.issues_found]
)
},
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
)
return RunShieldResponse(violation=violation)

View file

@ -10,9 +10,7 @@ from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(
config, LlamaGuardConfig
), f"Unexpected config type: {type(config)}"
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
impl = LlamaGuardSafetyImpl(config, deps)
await impl.initialize()

View file

@ -102,8 +102,7 @@ LLAMA_GUARD_MODEL_IDS = {
}
MODEL_TO_SAFETY_CATEGORIES_MAP = {
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES
+ [CAT_CODE_INTERPRETER_ABUSE],
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
"meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
"meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
}
@ -133,9 +132,7 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE = Template(
f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
)
PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
@ -233,9 +230,7 @@ class LlamaGuardShield:
if messages[0].role != Role.user.value:
raise ValueError("Messages must start with user")
if len(messages) >= 2 and (
messages[0].role == Role.user.value and messages[1].role == Role.user.value
):
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
messages = messages[1:]
for i in range(1, len(messages)):
@ -263,10 +258,7 @@ class LlamaGuardShield:
stream=True,
):
event = chunk.event
if (
event.event_type == ChatCompletionResponseEventType.progress
and event.delta.type == "text"
):
if event.event_type == ChatCompletionResponseEventType.progress and event.delta.type == "text":
content += event.delta.text
content = content.strip()
@ -313,10 +305,7 @@ class LlamaGuardShield:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[
f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
for m in messages
]
[f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}" for m in messages]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),

View file

@ -46,9 +46,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(
f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
)
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
async def run_shield(
self,
@ -71,9 +69,7 @@ class PromptGuardShield:
threshold: float = 0.9,
temperature: float = 1.0,
):
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
assert model_dir is not None, "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
@ -85,9 +81,7 @@ class PromptGuardShield:
# load model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
async def run(self, messages: List[Message]) -> RunShieldResponse:
message = messages[-1]
@ -117,10 +111,7 @@ class PromptGuardShield:
"violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
},
)
elif (
self.config.guard_type == PromptGuardType.jailbreak.value
and score_malicious > self.threshold
):
elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
violation_type=f"prompt_injection:malicious={score_malicious}",

View file

@ -54,15 +54,11 @@ class BasicScoringImpl(
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [
fn_def
for impl in self.scoring_fn_id_impls.values()
for fn_def in impl.get_supported_scoring_fn_defs()
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"basic"
), "All basic scoring fn must have identifier prefixed with 'basic'! "
assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "
return scoring_fn_defs_list
@ -76,9 +72,7 @@ class BasicScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
@ -108,12 +102,8 @@ class BasicScoringImpl(
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(
score_results, scoring_fn_id, scoring_fn_params
)
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -32,9 +32,7 @@ class EqualityScoringFn(RegisteredBaseScoringFn):
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert (
"generated_answer" in input_row
), "Generated answer not found in input row."
assert "generated_answer" in input_row, "Generated answer not found in input row."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]

View file

@ -18,7 +18,5 @@ equality = ScoringFn(
provider_id="basic",
provider_resource_id="equality",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -55,9 +55,7 @@ MULTILINGUAL_ANSWER_REGEXES = [
r"Àṣàyàn\s*:",
]
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
)
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
regex_parser_multiple_choice_answer = ScoringFn(
identifier="basic::regex_parser_multiple_choice_answer",
@ -66,10 +64,7 @@ regex_parser_multiple_choice_answer = ScoringFn(
provider_id="basic",
provider_resource_id="regex-parser-multiple-choice-answer",
params=RegexParserScoringFnParams(
parsing_regexes=[
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES
],
parsing_regexes=[MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES],
aggregation_functions=[AggregationFunctionType.accuracy],
),
)

View file

@ -18,7 +18,5 @@ subset_of = ScoringFn(
return_type=NumberType(),
provider_id="basic",
provider_resource_id="subset-of",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -33,17 +33,14 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert (
scoring_fn_identifier is not None
), "Scoring function identifier not found."
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
assert (
fn_def.params is not None
and fn_def.params.type == ScoringFnParamsType.regex_parser.value
), f"RegexParserScoringFnParams not found for {fn_def}."
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
f"RegexParserScoringFnParams not found for {fn_def}."
)
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]

View file

@ -124,12 +124,10 @@ class BraintrustScoringImpl(
self.datasets_api = datasets_api
self.braintrust_evaluators = {
entry.identifier: entry.evaluator
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
self.supported_fn_defs_registry = {
entry.identifier: entry.fn_def
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
async def initialize(self) -> None: ...
@ -139,16 +137,14 @@ class BraintrustScoringImpl(
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"braintrust"
), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
assert f.identifier.startswith("braintrust"), (
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
)
return scoring_fn_defs_list
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
raise NotImplementedError(
"Registering scoring function not allowed for braintrust provider"
)
raise NotImplementedError("Registering scoring function not allowed for braintrust provider")
async def set_api_key(self) -> None:
# api key is in the request headers
@ -171,17 +167,13 @@ class BraintrustScoringImpl(
await self.set_api_key()
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.score(
input_rows=all_rows.rows, scoring_functions=scoring_functions
)
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
@ -222,13 +214,8 @@ class BraintrustScoringImpl(
if scoring_fn_id not in self.supported_fn_defs_registry:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
score_results = [
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
aggregation_functions = self.supported_fn_defs_registry[
scoring_fn_id
].params.aggregation_functions
score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions
# override scoring_fn params if provided
if scoring_functions[scoring_fn_id] is not None:

View file

@ -21,7 +21,5 @@ answer_correctness_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="answer-correctness",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -20,7 +20,5 @@ answer_relevancy_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="answer-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -20,7 +20,5 @@ answer_similarity_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="answer-similarity",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -20,7 +20,5 @@ context_entity_recall_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="context-entity-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -20,7 +20,5 @@ context_precision_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="context-precision",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -20,7 +20,5 @@ context_recall_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="context-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -14,13 +14,10 @@ from llama_stack.apis.scoring_functions import (
context_relevancy_fn_def = ScoringFn(
identifier="braintrust::context-relevancy",
description=(
"Assesses how relevant the provided context is to the given question. "
"See: github.com/braintrustdata/autoevals"
"Assesses how relevant the provided context is to the given question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -21,7 +21,5 @@ factuality_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="factuality",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -20,7 +20,5 @@ faithfulness_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="faithfulness",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -16,8 +16,6 @@ async def get_provider_impl(
):
from .scoring import LlmAsJudgeScoringImpl
impl = LlmAsJudgeScoringImpl(
config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
)
impl = LlmAsJudgeScoringImpl(config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference])
await impl.initialize()
return impl

View file

@ -58,15 +58,13 @@ class LlmAsJudgeScoringImpl(
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [
fn_def
for impl in self.scoring_fn_id_impls.values()
for fn_def in impl.get_supported_scoring_fn_defs()
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"llm-as-judge"
), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
assert f.identifier.startswith("llm-as-judge"), (
"All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
)
return scoring_fn_defs_list
@ -80,9 +78,7 @@ class LlmAsJudgeScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
@ -112,12 +108,8 @@ class LlmAsJudgeScoringImpl(
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(
score_results, scoring_fn_id, scoring_fn_params
)
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -38,9 +38,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert (
scoring_fn_identifier is not None
), "Scoring function identifier not found."
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
# override params if scoring_params is provided
@ -48,12 +46,8 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
fn_def.params = scoring_params
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
assert (
fn_def.params.prompt_template is not None
), "LLM Judge prompt_template not found."
assert (
fn_def.params.judge_score_regexes is not None
), "LLM Judge judge_score_regexes not found."
assert fn_def.params.prompt_template is not None, "LLM Judge prompt_template not found."
assert fn_def.params.judge_score_regexes is not None, "LLM Judge judge_score_regexes not found."
input_query = input_row["input_query"]
expected_answer = input_row["expected_answer"]

View file

@ -44,15 +44,9 @@ class TelemetryConfig(BaseModel):
return v
@classmethod
def sample_run_config(
cls, __distro_dir__: str = "runtime", db_name: str = "trace_store.db"
) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:~/.llama/"
+ __distro_dir__
+ "/"
+ db_name
+ "}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:~/.llama/" + __distro_dir__ + "/" + db_name + "}",
}

View file

@ -27,7 +27,6 @@ COLORS = {
class ConsoleSpanProcessor(SpanProcessor):
def __init__(self, print_attributes: bool = False):
self.print_attributes = print_attributes
@ -35,9 +34,7 @@ class ConsoleSpanProcessor(SpanProcessor):
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
@ -49,9 +46,7 @@ class ConsoleSpanProcessor(SpanProcessor):
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
@ -79,9 +74,7 @@ class ConsoleSpanProcessor(SpanProcessor):
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
for event in span.events:
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime("%H:%M:%S.%f")[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
@ -96,11 +89,7 @@ class ConsoleSpanProcessor(SpanProcessor):
}
msg_color = severity_colors.get(severity, COLORS["white"])
print(
f" {event_time} "
f"{msg_color}[{severity.upper()}] "
f"{message}{COLORS['reset']}"
)
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
if event.attributes:
for key, value in event.attributes.items():

View file

@ -101,14 +101,10 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=resource, metric_readers=[metric_reader]
)
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(metric_provider)
if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(
SQLiteSpanProcessor(self.config.sqlite_db_path)
)
trace.get_tracer_provider().add_span_processor(SQLiteSpanProcessor(self.config.sqlite_db_path))
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
@ -154,9 +150,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
timestamp=timestamp_ns,
)
else:
print(
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
)
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
if name not in _GLOBAL_STORAGE["counters"]:
@ -181,21 +175,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
up_down_counter = self._get_or_create_up_down_counter(
event.metric, event.unit
)
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(
self, name: str, unit: str
) -> metrics.UpDownCounter:
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = (
self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
)
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
)
return _GLOBAL_STORAGE["up_down_counters"][name]

View file

@ -87,13 +87,9 @@ class CodeExecutor:
scripts = req.scripts
for i in range(len(scripts) - 1):
if req.only_last_cell_stdouterr:
scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
code=textwrap.indent(scripts[i], " " * 4)
)
scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(code=textwrap.indent(scripts[i], " " * 4))
if req.only_last_cell_fail:
scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
code=textwrap.indent(scripts[i], " " * 4)
)
scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(code=textwrap.indent(scripts[i], " " * 4))
# Seeds prefix:
seed = req.seed
@ -190,7 +186,7 @@ def execute_subprocess_request(request, ctx: CodeExecutionContext):
if request["type"] == "matplotlib":
return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
else:
raise Exception(f'Unrecognised network request type: {request["type"]}')
raise Exception(f"Unrecognised network request type: {request['type']}")
def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):

View file

@ -59,9 +59,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
)
]
async def invoke_tool(
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
script = kwargs["code"]
req = CodeExecutionRequest(scripts=[script])
res = self.code_executor.execute(req)

View file

@ -39,9 +39,7 @@ log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
@ -120,9 +118,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
return RAGQueryResult(content=None)
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
chunks, scores = zip(*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True))
tokens = 0
picked = []
@ -169,9 +165,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
),
]
async def invoke_tool(
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
raise RuntimeError(
"This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
)

View file

@ -11,9 +11,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import ChromaInlineImplConfig
async def get_provider_impl(
config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]
):
async def get_provider_impl(config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)

View file

@ -13,9 +13,7 @@ from .config import FaissImplConfig
async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissVectorIOImpl
assert isinstance(
config, FaissImplConfig
), f"Unexpected config type: {type(config)}"
assert isinstance(config, FaissImplConfig), f"Unexpected config type: {type(config)}"
impl = FaissVectorIOImpl(config, deps[Api.inference])
await impl.initialize()

View file

@ -59,10 +59,7 @@ class FaissIndex(EmbeddingIndex):
if stored_data:
data = json.loads(stored_data)
self.chunk_by_index = {
int(k): Chunk.model_validate_json(v)
for k, v in data["chunk_by_index"].items()
}
self.chunk_by_index = {int(k): Chunk.model_validate_json(v) for k, v in data["chunk_by_index"].items()}
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))
@ -75,9 +72,7 @@ class FaissIndex(EmbeddingIndex):
buffer = io.BytesIO()
np.savetxt(buffer, np_index)
data = {
"chunk_by_index": {
k: v.model_dump_json() for k, v in self.chunk_by_index.items()
},
"chunk_by_index": {k: v.model_dump_json() for k, v in self.chunk_by_index.items()},
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
}
@ -92,13 +87,9 @@ class FaissIndex(EmbeddingIndex):
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
# Add dimension check
embedding_dim = (
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
)
embedding_dim = embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
if embedding_dim != self.index.d:
raise ValueError(
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
)
raise ValueError(f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}")
indexlen = len(self.chunk_by_index)
for i, chunk in enumerate(chunks):
@ -109,12 +100,8 @@ class FaissIndex(EmbeddingIndex):
# Save updated index
await self._save_index()
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryChunksResponse:
distances, indices = self.index.search(
embedding.reshape(1, -1).astype(np.float32), k
)
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
distances, indices = self.index.search(embedding.reshape(1, -1).astype(np.float32), k)
chunks = []
scores = []
@ -145,9 +132,7 @@ class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
await FaissIndex.create(
vector_db.embedding_dimension, self.kvstore, vector_db.identifier
),
await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier),
self.inference_api,
)
self.cache[vector_db.identifier] = index
@ -169,9 +154,7 @@ class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
# Store in cache
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db=vector_db,
index=await FaissIndex.create(
vector_db.embedding_dimension, self.kvstore, vector_db.identifier
),
index=await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier),
inference_api=self.inference_api,
)
@ -195,9 +178,7 @@ class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
) -> None:
index = self.cache.get(vector_db_id)
if index is None:
raise ValueError(
f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}"
)
raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
await index.insert_chunks(chunks)