"""Auto-typing / rule-batch implementation moved to core.
This module contains the rule-based auto-typing logic formerly located in
`swcstudio.gui.rule_batch_processor`. It is kept in `swcstudio.core` so both GUI
and CLI can use the same implementation.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import zipfile
import math
import numpy as np
from swcstudio.core.config import load_feature_config, merge_config, save_feature_config
from swcstudio.core.reporting import (
auto_typing_log_path_for_file,
format_auto_typing_report_text,
operation_output_dir_for_folder,
operation_output_path_for_file,
operation_report_path_for_file,
operation_report_path_for_folder,
timestamp_slug,
write_text_report,
)
# Feature-config addressing used by _load_cfg / save_config.
TOOL = "batch_processing"
FEATURE = "auto_typing"
# Process-wide cache of the merged rules config; populated lazily by _load_cfg.
_DEFAULT_CFG: dict[str, Any] | None = None
def _default_rules_config() -> dict[str, Any]:
    """Return the built-in defaults for the rule-based auto-typing engine.

    User overrides loaded from the feature config are merged on top of this
    dict by ``_load_cfg``.  Class ids follow SWC convention, per
    ``class_labels``: 1=soma, 2=axon, 3=basal dendrite, 4=apical dendrite.
    """
    return {
        "class_labels": {"1": "soma", "2": "axon", "3": "basal", "4": "apical"},
        # Per-class feature weights used by _branch_scores on branch segments.
        "branch_score_weights": {
            "axon": {
                "path": 0.14,
                "radial": 0.12,
                "root_path": 0.18,
                "root_radial": 0.12,
                "radius": 0.12,
                "branch": 0.04,
                "persistence": 0.16,
                "taper": 0.08,
                "symmetry": 0.04,
                "up": 0.02,
                "prior": 0.04,
            },
            "apical": {
                "z": 0.20,
                "up": 0.20,
                "path": 0.12,
                "root_path": 0.16,
                "radius": 0.12,
                "branch": 0.10,
                "taper": 0.06,
                "symmetry": 0.04,
                "prior": 0.14,
            },
            "basal": {
                "z": 0.12,
                "up": 0.12,
                "branch": 0.16,
                "radius": 0.14,
                "path": 0.08,
                "root_path": 0.20,
                "root_radial": 0.12,
                "persistence": 0.08,
                "taper": 0.10,
                "symmetry": 0.08,
                "prior": 0.08,
            },
        },
        # Node-window sizes for windowed features (e.g. terminal taper).
        "feature_windows": {
            "terminal_window_nodes": 3,
        },
        # Maximum path length (um) of one linear chunk in _iter_branch_segment.
        "segmenting": {"max_chunk_path": 180.0},
        # Blend factors used by _ml_refine_scores (rule score vs. prototype similarity).
        "ml_blend": 0.28,
        "ml_base_weight": 0.72,
        # Minimum existing-label ratio for a branch to seed a class prototype.
        "seed_prior_threshold": 0.55,
        # Thresholds for force-assigning a class that received no branch (_assign_branches).
        "assign_missing": {"min_score": 0.58, "min_gain": -0.06},
        # Sibling-majority / parent-continuity smoothing (_smooth_branch_labels).
        "smoothing": {"maj_fraction": 0.67, "flip_margin": 0.10, "continuity_margin": 0.02},
        # Topology-aware refinement of branch labels.
        "refinement": {
            "iterations": 2,
            "parent_weight": 0.14,
            "child_weight": 0.18,
            "island_max_path": 36.0,
            "island_relative_max": 0.35,
            "island_flip_margin": 0.14,
        },
        # Soma-child subtree scoring (_assign_soma_child_subtrees) and how
        # strongly its verdicts influence branch-level scores.
        "soma_child_prior": {
            "branch_weight": 0.38,
            "branch_boost": 0.16,
            "propagation_weight": 0.30,
            "score_weights": {
                "axon": {
                    "path": 0.20,
                    "radial": 0.18,
                    "size": 0.18,
                    "radius": 0.10,
                    "branch": 0.08,
                    "persistence": 0.14,
                    "taper": 0.06,
                    "symmetry": 0.04,
                    "prior": 0.12,
                },
                "apical": {
                    "z": 0.22,
                    "up": 0.22,
                    "path": 0.14,
                    "size": 0.12,
                    "radius": 0.12,
                    "branch": 0.10,
                    "taper": 0.06,
                    "symmetry": 0.04,
                    "prior": 0.16,
                },
                "basal": {
                    "path": 0.20,
                    "radial": 0.18,
                    "radius": 0.14,
                    "branch": 0.12,
                    "z": 0.10,
                    "up": 0.10,
                    "persistence": 0.06,
                    "taper": 0.08,
                    "symmetry": 0.10,
                    "prior": 0.16,
                },
            },
        },
        # NOTE(review): "propagation_weights" is not consumed in this chunk —
        # presumably used by label propagation elsewhere; confirm before editing.
        "propagation_weights": {
            "self": 0.35,
            "parent": 0.35,
            "children": 0.20,
            "branch_prior": 0.30,
            "iterations": 4,
        },
        "radius": {"copy_parent_if_zero": True},
        # Hard constraints applied on top of the soft scores.
        "constraints": {
            "inherit_primary_subtree": True,
            "single_axon": True,
            "single_apical": True,
            "axon_primary_min_score": 0.42,
            "apical_primary_min_score": 0.42,
            "far_basal_distance_um": 500.0,
            "far_basal_penalty": 0.22,
            "thin_axon_max_base_radius_um": 1.0,
            "thin_axon_bonus": 0.10,
        },
        "notes": (
            "This JSON controls the auto-labeling behavior "
            "(weights, thresholds, and options), including hard primary-subtree inheritance, "
            "single-axon/apical constraints, path-persistence / terminal-taper features, "
            "and topology-aware refinement. Edit carefully."
        ),
    }
def _load_cfg() -> dict[str, Any]:
    """Return the merged rules config, loading and caching it on first use.

    The stored feature config may either nest its rules under a "rules" key
    or (presumably a legacy layout — the "feature" marker check suggests so)
    be the rules dict itself.  A shallow copy of the cached dict is returned
    so callers cannot mutate the cached defaults.
    """
    global _DEFAULT_CFG
    if _DEFAULT_CFG is None:
        feature_cfg = load_feature_config(TOOL, FEATURE, default={})
        if "rules" in feature_cfg:
            rules_cfg = feature_cfg["rules"]
        elif "feature" not in feature_cfg:
            rules_cfg = feature_cfg
        else:
            rules_cfg = {}
        _DEFAULT_CFG = merge_config(_default_rules_config(), rules_cfg)
    return dict(_DEFAULT_CFG)
def get_config() -> dict:
    """Return the active configuration dict (loaded from JSON if available)."""
    active = _load_cfg()
    return active
def save_config(cfg: dict) -> None:
    """Save rule settings into the batch auto-typing feature config.

    Persists ``cfg`` under the "rules" key of the stored feature config and
    refreshes the in-process cache with the built-in defaults merged
    underneath the new settings.
    """
    global _DEFAULT_CFG
    stored = load_feature_config(TOOL, FEATURE, default={})
    merged = merge_config(stored, {"rules": cfg})
    save_feature_config(TOOL, FEATURE, merged)
    _DEFAULT_CFG = merge_config(_default_rules_config(), cfg)
@dataclass
class RuleBatchOptions:
    """Flags selecting which rule-batch operations to run.

    NOTE(review): field meanings inferred from names and the config keys in
    this module — confirm against the batch runner that consumes them.
    """

    soma: bool = False  # relabel soma nodes (SWC type 1)
    axon: bool = False  # relabel axon nodes (SWC type 2)
    apic: bool = False  # relabel apical-dendrite nodes (SWC type 4)
    basal: bool = False  # relabel basal-dendrite nodes (SWC type 3)
    rad: bool = False  # apply radius fixes (see "radius" config section)
    zip_output: bool = False  # bundle outputs into a zip archive
@dataclass
class RuleBatchResult:
    """Aggregate outcome of a rule-batch run over a folder of SWC files.

    NOTE(review): field semantics inferred from names — confirm against the
    code that constructs this result (not visible in this chunk).
    """

    folder: str  # input folder that was processed
    out_dir: str  # directory outputs were written to
    zip_path: str | None  # archive path when zipping was requested, else None
    files_total: int  # files discovered
    files_processed: int  # files completed successfully
    files_failed: int  # files that failed
    total_nodes: int  # nodes seen across all files
    total_type_changes: int  # node-type relabels across all files
    total_radius_changes: int  # radius edits across all files
    failures: list[str]  # per-file failure messages
    per_file: list[str]  # per-file summary lines
    log_path: str | None  # path of the written report/log, if any
@dataclass
class RuleFileResult:
    """Outcome of auto-typing a single SWC file.

    NOTE(review): field semantics inferred from names — confirm against the
    per-file processing code (not visible in this chunk).
    """

    input_file: str  # source SWC path
    output_file: str | None  # written output path, or None on failure
    nodes_total: int  # nodes parsed from the file
    type_changes: int  # number of node-type relabels
    radius_changes: int  # number of radius edits
    out_type_counts: dict[int, int]  # final per-type node counts
    failures: list[str]  # failure messages for this file
    change_details: list[str]  # human-readable change descriptions
    log_path: str | None  # per-file log path, if written
    headers: list[str]  # SWC comment header lines (as parsed)
    rows: list[dict[str, Any]]  # parsed node rows (see _parse_swc)
    types: list[int]  # final per-node types, aligned with rows
    radii: list[float]  # final per-node radii, aligned with rows
def _parse_swc(path: Path) -> tuple[list[str], list[dict[str, Any]]]:
headers: list[str] = []
rows: list[dict[str, Any]] = []
with path.open("r", encoding="utf-8", errors="ignore") as fh:
for line in fh:
s = line.strip()
if not s:
continue
if s.startswith("#"):
headers.append(line.rstrip("\n"))
continue
parts = s.split()
if len(parts) < 7:
continue
try:
rid = int(float(parts[0]))
rtype = int(float(parts[1]))
x = float(parts[2])
y = float(parts[3])
z = float(parts[4])
radius = float(parts[5])
parent = int(float(parts[6]))
except Exception:
continue
rows.append(
{
"id": rid,
"type": rtype,
"x": x,
"y": y,
"z": z,
"radius": radius,
"parent": parent,
}
)
return headers, rows
def _build_topology(rows: list[dict[str, Any]]) -> tuple[list[int | None], list[list[int]], list[int]]:
n = len(rows)
id_to_idx = {int(row["id"]): i for i, row in enumerate(rows)}
parent_idx: list[int | None] = [None] * n
children: list[list[int]] = [[] for _ in range(n)]
for i, row in enumerate(rows):
pidx = id_to_idx.get(int(row["parent"]))
parent_idx[i] = pidx
if pidx is not None:
children[pidx].append(i)
roots = [i for i, row in enumerate(rows) if int(row["parent"]) == -1 or parent_idx[i] is None]
roots.sort(key=lambda idx: int(rows[idx]["id"]))
order: list[int] = []
seen = set()
queue = list(roots)
while queue:
idx = queue.pop(0)
if idx in seen:
continue
seen.add(idx)
order.append(idx)
kids = sorted(children[idx], key=lambda k: int(rows[k]["id"]))
queue.extend(kids)
for i in sorted(range(n), key=lambda idx: int(rows[idx]["id"])):
if i not in seen:
order.append(i)
return parent_idx, children, order
def _normalize_map(vals: dict[int, float]) -> dict[int, float]:
if not vals:
return {}
lo = min(vals.values())
hi = max(vals.values())
if hi <= lo:
return {k: 0.5 for k in vals}
scale = hi - lo
return {k: (v - lo) / scale for k, v in vals.items()}
def _iter_branch_segment(
start: int,
rows: list[dict[str, Any]],
children: list[list[int]],
parent_idx: list[int | None],
max_chunk_path: float,
) -> list[int]:
"""Return a linear segment until a leaf, bifurcation, or chunk limit is reached."""
out: list[int] = []
cur = start
seen = set()
chunk_path = 0.0
while cur not in seen:
seen.add(cur)
out.append(cur)
kids = children[cur]
if len(kids) != 1:
break
nxt = kids[0]
pidx = parent_idx[nxt]
if pidx is None:
break
dx = float(rows[nxt]["x"]) - float(rows[pidx]["x"])
dy = float(rows[nxt]["y"]) - float(rows[pidx]["y"])
dz = float(rows[nxt]["z"]) - float(rows[pidx]["z"])
seg_len = math.sqrt(dx * dx + dy * dy + dz * dz)
if out and chunk_path + seg_len > max_chunk_path:
break
chunk_path += seg_len
cur = nxt
return out
def _branch_partition(
    rows: list[dict[str, Any]],
    parent_idx: list[int | None],
    children: list[list[int]],
    types: list[int],
) -> tuple[dict[int, list[int]], dict[int, int], list[int]]:
    """Partition morphology into branch segments anchored at roots/bifurcations.

    Returns ``(branch_nodes, branch_anchor, node_branch)`` where
    ``branch_nodes`` maps branch id -> member node indices, ``branch_anchor``
    maps branch id -> the node the segment hangs off, and ``node_branch``
    maps node index -> branch id (-1 when never assigned, e.g. soma nodes).
    """
    n = len(rows)
    roots = [i for i, p in enumerate(parent_idx) if p is None]
    soma_roots = [i for i in roots if int(types[i]) == 1]
    # Prefer soma-typed roots as anchors; fall back to all roots.
    anchors = soma_roots if soma_roots else roots
    max_chunk_path = float(_load_cfg().get("segmenting", {}).get("max_chunk_path", 180.0))
    node_branch = [-1] * n
    branch_nodes: dict[int, list[int]] = {}
    branch_anchor: dict[int, int] = {}
    bid = 0
    seen_starts: set[int] = set()
    pending: list[tuple[int, int]] = []
    # Seed the worklist with (anchor, first-node) pairs for each anchor's children.
    for anchor in anchors:
        kids = sorted(children[anchor], key=lambda i: int(rows[i]["id"]))
        if not kids and int(types[anchor]) != 1:
            # Childless non-soma root: treat the root itself as a segment start.
            pending.append((parent_idx[anchor] if parent_idx[anchor] is not None else anchor, anchor))
            continue
        for child in kids:
            pending.append((anchor, child))
    # Drain the worklist: each start grows one linear segment, and the
    # segment's tail enqueues the next start(s).
    while pending:
        anchor, start = pending.pop(0)
        if start in seen_starts:
            continue
        seen_starts.add(start)
        nodes = _iter_branch_segment(start, rows, children, parent_idx, max_chunk_path)
        if not nodes:
            continue
        branch_nodes[bid] = nodes
        branch_anchor[bid] = anchor
        for x in nodes:
            if node_branch[x] == -1:
                node_branch[x] = bid
        tail = nodes[-1]
        kids = sorted(children[tail], key=lambda i: int(rows[i]["id"]))
        if len(kids) == 1:
            # Tail ended on the chunk-length limit: continue the same run.
            pending.append((tail, kids[0]))
        else:
            for child in kids:
                pending.append((tail, child))
        bid += 1
    # Sweep up nodes not reached from any anchor (disconnected pieces).
    for i in range(n):
        if node_branch[i] != -1 or int(types[i]) == 1:
            continue
        anchor = parent_idx[i] if parent_idx[i] is not None else i
        nodes = _iter_branch_segment(i, rows, children, parent_idx, max_chunk_path)
        branch_nodes[bid] = nodes
        branch_anchor[bid] = anchor
        for x in nodes:
            if node_branch[x] == -1:
                node_branch[x] = bid
        tail = nodes[-1]
        kids = sorted(children[tail], key=lambda j: int(rows[j]["id"]))
        # NOTE(review): these pending.append calls are dead — the worklist
        # loop above has already finished and pending is never drained again.
        # Downstream tails are still covered because the enclosing for-loop
        # visits every unassigned node.
        if len(kids) == 1:
            pending.append((tail, kids[0]))
        else:
            for child in kids:
                pending.append((tail, child))
        bid += 1
    return branch_nodes, branch_anchor, node_branch
def _compute_root_metrics(
rows: list[dict[str, Any]],
parent_idx: list[int | None],
children: list[list[int]],
order: list[int],
) -> tuple[list[float], list[float], list[int], int | None]:
n = len(rows)
roots = [i for i, p in enumerate(parent_idx) if p is None]
root_idx = roots[0] if roots else None
path_from_root = [0.0] * n
branch_order = [0] * n
for i in order:
pidx = parent_idx[i]
if pidx is None:
continue
dx = float(rows[i]["x"]) - float(rows[pidx]["x"])
dy = float(rows[i]["y"]) - float(rows[pidx]["y"])
dz = float(rows[i]["z"]) - float(rows[pidx]["z"])
path_from_root[i] = path_from_root[pidx] + math.sqrt(dx * dx + dy * dy + dz * dz)
branch_order[i] = branch_order[pidx] + (1 if len(children[pidx]) > 1 else 0)
if root_idx is None:
return path_from_root, [0.0] * n, branch_order, None
rx = float(rows[root_idx]["x"])
ry = float(rows[root_idx]["y"])
rz = float(rows[root_idx]["z"])
radial_from_root = [
math.sqrt(
(float(row["x"]) - rx) ** 2
+ (float(row["y"]) - ry) ** 2
+ (float(row["z"]) - rz) ** 2
)
for row in rows
]
return path_from_root, radial_from_root, branch_order, root_idx
def _window_mean(vals: list[float], count: int) -> float:
if not vals:
return 0.0
take = max(1, min(int(count), len(vals)))
return float(sum(vals[:take]) / take)
def _terminal_taper_ratio(nodes: list[int], rows: list[dict[str, Any]], *, window_nodes: int) -> float:
if not nodes:
return 1.0
radii = [max(0.0, float(rows[idx]["radius"])) for idx in nodes]
count = max(1, min(int(window_nodes), len(radii)))
base = _window_mean(radii, count)
tail = _window_mean(list(reversed(radii)), count)
if base <= 1e-9:
return 1.0
return max(0.0, float(tail / base))
def _terminal_up_alignment(nodes: list[int], rows: list[dict[str, Any]]) -> float:
if len(nodes) <= 1:
return 0.5
start = rows[nodes[0]]
end = rows[nodes[-1]]
dx = float(end["x"]) - float(start["x"])
dy = float(end["y"]) - float(start["y"])
dz = float(end["z"]) - float(start["z"])
dist = math.sqrt(dx * dx + dy * dy + dz * dz)
if dist <= 1e-9:
return 0.5
return max(0.0, min(1.0, (dz / dist + 1.0) * 0.5))
def _directional_persistence(path_length: float, euclidean_distance: float) -> float:
if path_length <= 1e-9 or euclidean_distance <= 1e-9:
return 0.5
return max(0.0, min(1.0, float(euclidean_distance / path_length)))
def _branch_symmetry(anchor: int, rows: list[dict[str, Any]], children: list[list[int]]) -> float:
kids = children[anchor]
if len(kids) <= 1:
return 0.5
child_radii = [max(0.0, float(rows[idx]["radius"])) for idx in kids]
if not child_radii:
return 0.5
med = float(np.median(child_radii))
if med <= 1e-9:
return 0.5
mad = float(np.mean([abs(val - med) for val in child_radii]))
symmetry = 1.0 - min(1.0, mad / med)
return max(0.0, min(1.0, symmetry))
def _soma_child_owners(
rows: list[dict[str, Any]],
parent_idx: list[int | None],
children: list[list[int]],
types: list[int],
) -> tuple[int | None, dict[int, list[int]], list[int | None]]:
soma_roots = [i for i, p in enumerate(parent_idx) if p is None and int(types[i]) == 1]
root_idx = soma_roots[0] if soma_roots else next((i for i, p in enumerate(parent_idx) if p is None), None)
owners: list[int | None] = [None] * len(rows)
child_nodes: dict[int, list[int]] = {}
if root_idx is None:
return None, child_nodes, owners
for child in sorted(children[root_idx], key=lambda i: int(rows[i]["id"])):
stack = [child]
child_nodes[child] = []
while stack:
idx = stack.pop()
if owners[idx] is not None:
continue
owners[idx] = child
child_nodes[child].append(idx)
stack.extend(children[idx])
return root_idx, child_nodes, owners
def _assign_soma_child_subtrees(
    rows: list[dict[str, Any]],
    parent_idx: list[int | None],
    children: list[list[int]],
    types: list[int],
    enabled_neurites: set[int],
    path_from_root: list[float],
    radial_from_root: list[float],
) -> tuple[dict[int, int], dict[int, dict[int, float]], list[int | None]]:
    """Score each soma-child subtree against the enabled classes.

    Computes per-subtree aggregate features (size, reach, radius, branching,
    height, persistence, taper, up-alignment, symmetry, existing-label
    prior), normalizes them across subtrees, and combines them with the
    ``soma_child_prior.score_weights`` from the config.  Returns
    ``(child -> assigned class, child -> per-class scores, per-node owner)``.
    Class ids: 2=axon, 4=apical, anything else scored as basal.
    """
    soma_idx, child_nodes, node_child_owner = _soma_child_owners(rows, parent_idx, children, types)
    if soma_idx is None or not child_nodes or not enabled_neurites:
        return {}, {}, node_child_owner
    # --- raw per-subtree features ---
    node_count = {child: float(len(nodes)) for child, nodes in child_nodes.items()}
    path_max = {child: max(path_from_root[i] for i in nodes) for child, nodes in child_nodes.items()}
    radial_max = {child: max(radial_from_root[i] for i in nodes) for child, nodes in child_nodes.items()}
    mean_radius = {
        child: sum(float(rows[i]["radius"]) for i in nodes) / max(1, len(nodes))
        for child, nodes in child_nodes.items()
    }
    branch_density = {
        child: sum(1 for i in nodes if len(children[i]) > 1) / max(1, len(nodes))
        for child, nodes in child_nodes.items()
    }
    soma_z = float(rows[soma_idx]["z"])
    z_max_rel = {
        child: max(float(rows[i]["z"]) - soma_z for i in nodes)
        for child, nodes in child_nodes.items()
    }
    terminal_window = int(_load_cfg().get("feature_windows", {}).get("terminal_window_nodes", 3))
    persistence = {
        child: _directional_persistence(path_max.get(child, 0.0), radial_max.get(child, 0.0))
        for child in child_nodes
    }
    taper_ratio = {
        child: _terminal_taper_ratio(nodes, rows, window_nodes=terminal_window)
        for child, nodes in child_nodes.items()
    }
    up_alignment = {
        child: _terminal_up_alignment(nodes, rows)
        for child, nodes in child_nodes.items()
    }
    symmetry = {
        child: _branch_symmetry(child, rows, children)
        for child in child_nodes
    }
    # Fraction of nodes already labeled with each enabled class.
    existing_ratio: dict[tuple[int, int], float] = {}
    for child, nodes in child_nodes.items():
        for cls in enabled_neurites:
            existing_ratio[(child, cls)] = sum(1 for i in nodes if int(types[i]) == cls) / max(1, len(nodes))
    # --- normalized features (min-max across subtrees; 0.5 = neutral) ---
    n_size = _normalize_map(node_count)
    n_path = _normalize_map(path_max)
    n_radial = _normalize_map(radial_max)
    n_radius = _normalize_map(mean_radius)
    n_branch = _normalize_map(branch_density)
    n_z = _normalize_map(z_max_rel)
    n_persistence = dict(persistence)
    # Strong taper (tip much thinner than base) scores high for dendrites;
    # the inverse (uniform caliber) scores high for axons.
    n_taper = {child: 1.0 - min(1.0, max(0.0, taper_ratio.get(child, 1.0))) for child in child_nodes}
    n_taper_axon = {child: min(1.0, max(0.0, taper_ratio.get(child, 1.0))) for child in child_nodes}
    n_up = dict(up_alignment)
    n_symmetry = dict(symmetry)
    cfg = _load_cfg().get("soma_child_prior", {})
    constraints = _load_cfg().get("constraints", {})
    far_basal_distance_um = float(constraints.get("far_basal_distance_um", 500.0))
    far_basal_penalty = float(constraints.get("far_basal_penalty", 0.22))
    thin_axon_max_base_radius = float(constraints.get("thin_axon_max_base_radius_um", 1.0))
    thin_axon_bonus = float(constraints.get("thin_axon_bonus", 0.10))
    weights = cfg.get("score_weights", {})
    # --- weighted per-class score per subtree ---
    child_scores: dict[int, dict[int, float]] = {}
    for child in child_nodes:
        sc: dict[int, float] = {}
        for cls in enabled_neurites:
            prior = existing_ratio.get((child, cls), 0.0)
            if cls == 2:
                # Axon: long reach, thin, sparse branching, straight, uniform caliber.
                w = weights.get("axon", {})
                s = (
                    w.get("path", 0.20) * n_path.get(child, 0.5)
                    + w.get("radial", 0.18) * n_radial.get(child, 0.5)
                    + w.get("size", 0.18) * n_size.get(child, 0.5)
                    + w.get("radius", 0.10) * (1.0 - n_radius.get(child, 0.5))
                    + w.get("branch", 0.08) * (1.0 - n_branch.get(child, 0.5))
                    + w.get("persistence", 0.14) * n_persistence.get(child, 0.5)
                    + w.get("taper", 0.06) * n_taper_axon.get(child, 0.5)
                    + w.get("symmetry", 0.04) * (1.0 - n_symmetry.get(child, 0.5))
                    + w.get("prior", 0.12) * prior
                )
                # Bonus when the trunk leaves the soma with a thin base radius.
                if float(rows[child]["radius"]) <= thin_axon_max_base_radius:
                    s += thin_axon_bonus
            elif cls == 4:
                # Apical: tall, upward-pointing, thick, well-branched trunk.
                w = weights.get("apical", {})
                s = (
                    w.get("z", 0.22) * n_z.get(child, 0.5)
                    + w.get("up", 0.22) * n_up.get(child, 0.5)
                    + w.get("path", 0.14) * n_path.get(child, 0.5)
                    + w.get("size", 0.12) * n_size.get(child, 0.5)
                    + w.get("radius", 0.12) * n_radius.get(child, 0.5)
                    + w.get("branch", 0.10) * n_branch.get(child, 0.5)
                    + w.get("taper", 0.06) * n_taper.get(child, 0.5)
                    + w.get("symmetry", 0.04) * n_symmetry.get(child, 0.5)
                    + w.get("prior", 0.16) * prior
                )
            else:
                # Basal (default): short, local, bushy, tapering.
                w = weights.get("basal", {})
                s = (
                    w.get("path", 0.20) * (1.0 - n_path.get(child, 0.5))
                    + w.get("radial", 0.18) * (1.0 - n_radial.get(child, 0.5))
                    + w.get("radius", 0.14) * n_radius.get(child, 0.5)
                    + w.get("branch", 0.12) * n_branch.get(child, 0.5)
                    + w.get("z", 0.10) * (1.0 - n_z.get(child, 0.5))
                    + w.get("up", 0.10) * (1.0 - n_up.get(child, 0.5))
                    + w.get("persistence", 0.06) * (1.0 - n_persistence.get(child, 0.5))
                    + w.get("taper", 0.08) * n_taper.get(child, 0.5)
                    + w.get("symmetry", 0.10) * n_symmetry.get(child, 0.5)
                    + w.get("prior", 0.16) * prior
                )
                # Subtrees reaching very far from the soma are unlikely basal.
                if max(path_max.get(child, 0.0), radial_max.get(child, 0.0)) > far_basal_distance_um:
                    s -= far_basal_penalty
            sc[cls] = s
        child_scores[child] = sc
    child_class = _assign_branches(child_nodes, child_scores, enabled_neurites)
    return child_class, child_scores, node_child_owner
def _pick_best_class(scores: dict[int, float], allowed: set[int]) -> int | None:
if not allowed:
return None
ordered = sorted(allowed)
return max(ordered, key=lambda cls: (float(scores.get(cls, float("-inf"))), -int(cls)))
def _enforce_primary_subtree_constraints(
    child_scores: dict[int, dict[int, float]],
    enabled_neurites: set[int],
) -> dict[int, int]:
    """Assign classes to soma children, enforcing single-axon/apical rules.

    Starts from the plain best-score assignment, then (when the
    ``inherit_primary_subtree`` constraint is on) elects at most one axon
    owner and one apical owner — the highest-scoring child above the
    configured minimum — and forbids the remaining children from taking
    those classes where another class is still available.
    """
    if not child_scores or not enabled_neurites:
        return {}
    cfg = _load_cfg().get("constraints", {})
    inherit_primary_subtree = bool(cfg.get("inherit_primary_subtree", True))
    single_axon = bool(cfg.get("single_axon", True))
    single_apical = bool(cfg.get("single_apical", True))
    axon_primary_min = float(cfg.get("axon_primary_min_score", 0.42))
    apical_primary_min = float(cfg.get("apical_primary_min_score", 0.42))
    child_ids = sorted(child_scores)
    # Unconstrained baseline: each child is its own single-node "branch".
    out = _assign_branches({child: [child] for child in child_ids}, child_scores, enabled_neurites)
    if not inherit_primary_subtree:
        return out
    # Elect the single axon owner (class 2), if any child scores high enough.
    axon_owner: int | None = None
    if single_axon and 2 in enabled_neurites and child_ids:
        best = max(child_ids, key=lambda child: float(child_scores.get(child, {}).get(2, float("-inf"))))
        if float(child_scores.get(best, {}).get(2, float("-inf"))) >= axon_primary_min:
            axon_owner = best
    # Elect the single apical owner (class 4) among the remaining children.
    apical_owner: int | None = None
    if single_apical and 4 in enabled_neurites and child_ids:
        remaining = [child for child in child_ids if child != axon_owner]
        if remaining:
            best = max(remaining, key=lambda child: float(child_scores.get(child, {}).get(4, float("-inf"))))
            if float(child_scores.get(best, {}).get(4, float("-inf"))) >= apical_primary_min:
                apical_owner = best
    fallback_shared = set(enabled_neurites)
    # NOTE(review): this conditional assigns the same value on both paths, so
    # it is currently a no-op — possibly an unfinished attempt to restrict the
    # fallback to non-axon/non-apical classes. Confirm intent before changing.
    if len(fallback_shared - {2, 4}) <= 0:
        fallback_shared = set(enabled_neurites)
    for child in child_ids:
        if child == axon_owner:
            out[child] = 2
            continue
        if child == apical_owner:
            out[child] = 4
            continue
        # Non-owners may not take an owned class unless nothing else remains.
        allowed = set(enabled_neurites)
        if single_axon and axon_owner is not None and 2 in allowed and len(allowed - {2}) >= 1:
            allowed.discard(2)
        if single_apical and apical_owner is not None and 4 in allowed and len(allowed - {4}) >= 1:
            allowed.discard(4)
        if not allowed:
            allowed = set(fallback_shared)
        picked = _pick_best_class(child_scores.get(child, {}), allowed)
        if picked is not None:
            out[child] = picked
    return out
def _branch_scores(
    rows: list[dict[str, Any]],
    parent_idx: list[int | None],
    children: list[list[int]],
    types: list[int],
    branch_nodes: dict[int, list[int]],
    branch_anchor: dict[int, int],
    enabled_neurites: set[int],
    path_from_root: list[float],
    radial_from_root: list[float],
    node_child_owner: list[int | None],
    child_class: dict[int, int],
    child_scores: dict[int, dict[int, float]],
) -> tuple[dict[int, dict[int, float]], dict[int, tuple[float, ...]], dict[tuple[int, int], float]]:
    """Score every branch segment against each enabled class.

    Computes per-branch geometric features, normalizes them across branches,
    and combines them with the ``branch_score_weights`` from the config,
    folding in the owning soma-child's verdict from
    ``_assign_soma_child_subtrees``.  Returns ``(branch -> per-class score,
    branch -> normalized feature vector, (branch, class) -> existing-label
    ratio)``.  Class ids: 2=axon, 4=apical, anything else scored as basal.
    """
    x = [float(r["x"]) for r in rows]
    y = [float(r["y"]) for r in rows]
    z = [float(r["z"]) for r in rows]
    rad = [float(r["radius"]) for r in rows]
    # --- raw per-branch features ---
    path_len: dict[int, float] = {}
    radial_extent: dict[int, float] = {}
    mean_radius: dict[int, float] = {}
    branchiness: dict[int, float] = {}
    z_mean_rel: dict[int, float] = {}
    root_path_mean: dict[int, float] = {}
    root_radial_mean: dict[int, float] = {}
    persistence: dict[int, float] = {}
    taper_ratio: dict[int, float] = {}
    up_alignment: dict[int, float] = {}
    symmetry: dict[int, float] = {}
    existing_ratio: dict[tuple[int, int], float] = {}
    terminal_window = int(_load_cfg().get("feature_windows", {}).get("terminal_window_nodes", 3))
    for bid, nodes in branch_nodes.items():
        a = branch_anchor[bid]
        ax, ay, az = x[a], y[a], z[a]
        plen = 0.0
        max_r = 0.0
        bif = 0
        for i in nodes:
            p = parent_idx[i]
            if p is not None:
                dx = x[i] - x[p]
                dy = y[i] - y[p]
                dz = z[i] - z[p]
                plen += math.sqrt(dx * dx + dy * dy + dz * dz)
            # Farthest straight-line reach from the branch anchor.
            dxa = x[i] - ax
            dya = y[i] - ay
            dza = z[i] - az
            max_r = max(max_r, math.sqrt(dxa * dxa + dya * dya + dza * dza))
            if len(children[i]) > 1:
                bif += 1
        path_len[bid] = plen
        radial_extent[bid] = max_r
        mean_radius[bid] = sum(rad[i] for i in nodes) / max(1, len(nodes))
        branchiness[bid] = bif / max(1, len(nodes))
        z_mean_rel[bid] = sum((z[i] - az) for i in nodes) / max(1, len(nodes))
        root_path_mean[bid] = sum(path_from_root[i] for i in nodes) / max(1, len(nodes))
        root_radial_mean[bid] = sum(radial_from_root[i] for i in nodes) / max(1, len(nodes))
        persistence[bid] = _directional_persistence(plen, max_r)
        taper_ratio[bid] = _terminal_taper_ratio(nodes, rows, window_nodes=terminal_window)
        up_alignment[bid] = _terminal_up_alignment([a] + nodes, rows)
        symmetry[bid] = _branch_symmetry(a, rows, children)
        for cls in enabled_neurites:
            c = sum(1 for i in nodes if int(types[i]) == cls)
            existing_ratio[(bid, cls)] = c / max(1, len(nodes))
    # --- normalized features (min-max across branches; 0.5 = neutral) ---
    n_path = _normalize_map(path_len)
    n_radial = _normalize_map(radial_extent)
    n_radius = _normalize_map(mean_radius)
    n_branch = _normalize_map(branchiness)
    n_z = _normalize_map(z_mean_rel)
    n_root_path = _normalize_map(root_path_mean)
    n_root_radial = _normalize_map(root_radial_mean)
    n_persistence = dict(persistence)
    # Strong taper favors dendrites; uniform caliber favors axons.
    n_taper = {bid: 1.0 - min(1.0, max(0.0, taper_ratio.get(bid, 1.0))) for bid in branch_nodes}
    n_taper_axon = {bid: min(1.0, max(0.0, taper_ratio.get(bid, 1.0))) for bid in branch_nodes}
    n_up = dict(up_alignment)
    n_symmetry = dict(symmetry)
    scores: dict[int, dict[int, float]] = {}
    features: dict[int, tuple[float, ...]] = {}
    cfg = _load_cfg()
    constraints = cfg.get("constraints", {})
    far_basal_distance_um = float(constraints.get("far_basal_distance_um", 500.0))
    far_basal_penalty = float(constraints.get("far_basal_penalty", 0.22))
    thin_axon_max_base_radius = float(constraints.get("thin_axon_max_base_radius_um", 1.0))
    thin_axon_bonus = float(constraints.get("thin_axon_bonus", 0.10))
    weights = cfg.get("branch_score_weights", {})
    child_prior_cfg = cfg.get("soma_child_prior", {})
    child_branch_weight = float(child_prior_cfg.get("branch_weight", 0.38))
    child_branch_boost = float(child_prior_cfg.get("branch_boost", 0.16))
    for bid in branch_nodes:
        # Fixed-order feature vector consumed by _ml_refine_scores prototypes.
        features[bid] = (
            n_path.get(bid, 0.5),
            n_radial.get(bid, 0.5),
            n_radius.get(bid, 0.5),
            n_branch.get(bid, 0.5),
            n_z.get(bid, 0.5),
            n_root_path.get(bid, 0.5),
            n_root_radial.get(bid, 0.5),
            n_persistence.get(bid, 0.5),
            n_taper.get(bid, 0.5),
            n_up.get(bid, 0.5),
            n_symmetry.get(bid, 0.5),
        )
        # Soma child whose subtree this branch belongs to (first node with an owner).
        owner = next((node_child_owner[i] for i in branch_nodes[bid] if node_child_owner[i] is not None), None)
        br_scores: dict[int, float] = {}
        for cls in enabled_neurites:
            prior = existing_ratio.get((bid, cls), 0.0)
            if cls == 2:
                # Axon: far-reaching, thin, unbranched, straight, uniform caliber.
                w = weights.get("axon", {})
                s = (
                    w.get("path", 0.14) * n_path.get(bid, 0.5)
                    + w.get("radial", 0.12) * n_radial.get(bid, 0.5)
                    + w.get("root_path", 0.18) * n_root_path.get(bid, 0.5)
                    + w.get("root_radial", 0.12) * n_root_radial.get(bid, 0.5)
                    + w.get("radius", 0.12) * (1.0 - n_radius.get(bid, 0.5))
                    + w.get("branch", 0.04) * (1.0 - n_branch.get(bid, 0.5))
                    + w.get("persistence", 0.16) * n_persistence.get(bid, 0.5)
                    + w.get("taper", 0.08) * n_taper_axon.get(bid, 0.5)
                    + w.get("symmetry", 0.04) * (1.0 - n_symmetry.get(bid, 0.5))
                    + w.get("up", 0.02) * (1.0 - abs(0.5 - n_up.get(bid, 0.5)) * 2.0)
                    + w.get("prior", 0.04) * prior
                )
                if float(rows[branch_nodes[bid][0]]["radius"]) <= thin_axon_max_base_radius:
                    s += thin_axon_bonus
            elif cls == 4:
                # Apical: high, upward, thick, well-branched.
                w = weights.get("apical", {})
                s = (
                    w.get("z", 0.20) * n_z.get(bid, 0.5)
                    + w.get("up", 0.20) * n_up.get(bid, 0.5)
                    + w.get("path", 0.12) * n_path.get(bid, 0.5)
                    + w.get("root_path", 0.16) * n_root_path.get(bid, 0.5)
                    + w.get("radius", 0.12) * n_radius.get(bid, 0.5)
                    + w.get("branch", 0.10) * n_branch.get(bid, 0.5)
                    + w.get("taper", 0.06) * n_taper.get(bid, 0.5)
                    + w.get("symmetry", 0.04) * n_symmetry.get(bid, 0.5)
                    + w.get("prior", 0.14) * prior
                )
            else:
                # Basal (default): short, local, bushy, tapering.
                w = weights.get("basal", {})
                s = (
                    w.get("z", 0.12) * (1.0 - n_z.get(bid, 0.5))
                    + w.get("up", 0.12) * (1.0 - n_up.get(bid, 0.5))
                    + w.get("branch", 0.16) * n_branch.get(bid, 0.5)
                    + w.get("radius", 0.14) * n_radius.get(bid, 0.5)
                    + w.get("path", 0.08) * (1.0 - n_path.get(bid, 0.5))
                    + w.get("root_path", 0.20) * (1.0 - n_root_path.get(bid, 0.5))
                    + w.get("root_radial", 0.12) * (1.0 - n_root_radial.get(bid, 0.5))
                    + w.get("persistence", 0.08) * (1.0 - n_persistence.get(bid, 0.5))
                    + w.get("taper", 0.10) * n_taper.get(bid, 0.5)
                    + w.get("symmetry", 0.08) * n_symmetry.get(bid, 0.5)
                    + w.get("prior", 0.08) * prior
                )
                # Penalty applies only to the basal score: far branches are unlikely basal.
                if max(root_path_mean.get(bid, 0.0), root_radial_mean.get(bid, 0.0)) > far_basal_distance_um:
                    s -= far_basal_penalty
            # Fold in the owning soma-child's per-class verdict for every class.
            if owner is not None and owner in child_scores:
                s += child_branch_weight * child_scores[owner].get(cls, 0.0)
                if child_class.get(owner) == cls:
                    s += child_branch_boost
            br_scores[cls] = s
        scores[bid] = br_scores
    return scores, features, existing_ratio
def _enforce_owner_labels_on_branches(
branch_class: dict[int, int],
branch_nodes: dict[int, list[int]],
node_child_owner: list[int | None],
child_class: dict[int, int],
) -> dict[int, int]:
if not branch_class or not child_class:
return branch_class
if not bool(_load_cfg().get("constraints", {}).get("inherit_primary_subtree", True)):
return branch_class
out = dict(branch_class)
for bid, nodes in branch_nodes.items():
owner = next((node_child_owner[i] for i in nodes if i < len(node_child_owner) and node_child_owner[i] is not None), None)
if owner is not None and owner in child_class:
out[bid] = int(child_class[owner])
return out
def _euclid_similarity(a: tuple[float, ...], b: tuple[float, ...]) -> float:
d2 = 0.0
for x, y in zip(a, b):
d = x - y
d2 += d * d
dist = math.sqrt(d2)
max_dist = math.sqrt(float(len(a)))
if max_dist <= 0:
return 0.5
sim = 1.0 - (dist / max_dist)
return max(0.0, min(1.0, sim))
def _ml_refine_scores(
scores: dict[int, dict[int, float]],
features: dict[int, tuple[float, ...]],
existing_ratio: dict[tuple[int, int], float],
enabled_neurites: set[int],
) -> dict[int, dict[int, float]]:
if not scores or not enabled_neurites:
return scores
classes = sorted(enabled_neurites)
branch_ids = sorted(features.keys())
if not branch_ids:
return scores
cfg = _load_cfg()
seed_map: dict[int, list[int]] = {c: [] for c in classes}
seed_prior_threshold = float(cfg.get("seed_prior_threshold", 0.55))
for bid in branch_ids:
priors = {c: existing_ratio.get((bid, c), 0.0) for c in classes}
best_c = max(classes, key=lambda c: priors[c])
if priors[best_c] >= seed_prior_threshold:
seed_map[best_c].append(bid)
for c in classes:
if seed_map[c]:
continue
best_bid = max(branch_ids, key=lambda b: scores.get(b, {}).get(c, -1e9))
seed_map[c].append(best_bid)
prototypes: dict[int, tuple[float, ...]] = {}
for c in classes:
seeds = seed_map[c]
if not seeds:
continue
dim = len(features[seeds[0]])
acc = [0.0] * dim
for b in seeds:
fv = features[b]
for i in range(dim):
acc[i] += fv[i]
n = float(len(seeds))
prototypes[c] = tuple(v / n for v in acc)
out: dict[int, dict[int, float]] = {}
for bid in branch_ids:
out[bid] = {}
fv = features[bid]
for c in classes:
base = scores.get(bid, {}).get(c, 0.0)
proto = prototypes.get(c)
if proto is None:
out[bid][c] = base
continue
sim = _euclid_similarity(fv, proto)
ml_blend = float(cfg.get("ml_blend", 0.28))
ml_base = float(cfg.get("ml_base_weight", 0.72))
out[bid][c] = ml_base * base + ml_blend * sim
return out
def _assign_branches(
branch_nodes: dict[int, list[int]],
scores: dict[int, dict[int, float]],
enabled_neurites: set[int],
) -> dict[int, int]:
if not branch_nodes or not enabled_neurites:
return {}
selected = sorted(enabled_neurites)
assign: dict[int, int] = {}
if len(selected) == 1:
only = selected[0]
for bid in branch_nodes:
assign[bid] = only
return assign
for bid in branch_nodes:
b_scores = scores.get(bid, {})
cls = max(selected, key=lambda c: b_scores.get(c, -1e9))
assign[bid] = cls
missing = [c for c in selected if c not in set(assign.values())]
if missing and len(branch_nodes) >= len(selected):
for need in missing:
best_bid = None
best_gain = -1e9
best_need_score = -1e9
for bid in branch_nodes:
cur = assign[bid]
cur_s = scores.get(bid, {}).get(cur, 0.0)
need_s = scores.get(bid, {}).get(need, 0.0)
gain = need_s - cur_s
if gain > best_gain:
best_gain = gain
best_need_score = need_s
best_bid = bid
assign_cfg = _load_cfg().get("assign_missing", {})
min_score = float(assign_cfg.get("min_score", 0.58))
min_gain = float(assign_cfg.get("min_gain", -0.06))
if best_bid is not None and (best_need_score >= min_score and best_gain >= min_gain):
assign[best_bid] = need
return assign
def _smooth_branch_labels(
    branch_class: dict[int, int],
    scores: dict[int, dict[int, float]],
    branch_anchor: dict[int, int],
    node_branch: list[int] | None = None,
) -> dict[int, int]:
    """Smooth branch labels using sibling majority and parent continuity.

    Two passes per branch: (1) if enough same-anchor siblings agree on a
    different class and this branch's own score advantage is below
    ``flip_margin``, adopt the sibling majority; (2) if the parent branch
    (the branch containing this branch's anchor node, via ``node_branch``)
    carries a different class and the score advantage is below
    ``continuity_margin``, adopt the parent's class.  Input is not mutated.
    """
    if not branch_class:
        return branch_class
    out = dict(branch_class)
    # Group branches by their shared anchor node to find siblings.
    anchor_to_branches: dict[int, list[int]] = {}
    for bid, a in branch_anchor.items():
        anchor_to_branches.setdefault(a, []).append(bid)
    smooth_cfg = _load_cfg().get("smoothing", {})
    maj_frac = float(smooth_cfg.get("maj_fraction", 0.67))
    flip_margin = float(smooth_cfg.get("flip_margin", 0.10))
    continuity_margin = float(smooth_cfg.get("continuity_margin", 0.02))
    # Iterate a snapshot: flips feed into later branches' majority counts via `out`.
    for bid, cur_cls in list(out.items()):
        anchor = branch_anchor.get(bid)
        sibs = [s for s in anchor_to_branches.get(anchor, []) if s != bid]
        if len(sibs) >= 2:
            counts: dict[int, int] = {}
            for s in sibs:
                c = out.get(s)
                if c is None:
                    continue
                counts[c] = counts.get(c, 0) + 1
            if counts:
                maj_cls, maj_count = max(counts.items(), key=lambda kv: kv[1])
                if maj_cls != cur_cls and maj_count / max(1, len(sibs)) >= maj_frac:
                    cur_score = scores.get(bid, {}).get(cur_cls, 0.0)
                    maj_score = scores.get(bid, {}).get(maj_cls, 0.0)
                    # Flip only when the branch is not strongly committed.
                    if cur_score - maj_score < flip_margin:
                        out[bid] = maj_cls
                        cur_cls = maj_cls
        # Parent-continuity pass requires a node->branch map and a valid anchor.
        if node_branch is None or anchor is None or anchor < 0 or anchor >= len(node_branch):
            continue
        parent_bid = node_branch[anchor]
        if parent_bid == -1 or parent_bid == bid or parent_bid not in out:
            continue
        parent_cls = out[parent_bid]
        if parent_cls == cur_cls:
            continue
        cur_score = scores.get(bid, {}).get(cur_cls, 0.0)
        parent_score = scores.get(bid, {}).get(parent_cls, 0.0)
        if cur_score - parent_score < continuity_margin:
            out[bid] = parent_cls
    return out
def _branch_path_lengths(
rows: list[dict[str, Any]],
parent_idx: list[int | None],
branch_nodes: dict[int, list[int]],
) -> dict[int, float]:
lengths: dict[int, float] = {}
for bid, nodes in branch_nodes.items():
plen = 0.0
for i in nodes:
pidx = parent_idx[i]
if pidx is None:
continue
dx = float(rows[i]["x"]) - float(rows[pidx]["x"])
dy = float(rows[i]["y"]) - float(rows[pidx]["y"])
dz = float(rows[i]["z"]) - float(rows[pidx]["z"])
plen += math.sqrt(dx * dx + dy * dy + dz * dz)
lengths[bid] = plen
return lengths
def _branch_graph(
branch_anchor: dict[int, int],
node_branch: list[int],
) -> tuple[dict[int, int], dict[int, list[int]]]:
branch_parent: dict[int, int] = {}
branch_children: dict[int, list[int]] = {}
for bid, anchor in branch_anchor.items():
if anchor < 0 or anchor >= len(node_branch):
continue
parent_bid = node_branch[anchor]
if parent_bid == -1 or parent_bid == bid:
continue
branch_parent[bid] = parent_bid
branch_children.setdefault(parent_bid, []).append(bid)
return branch_parent, branch_children
def _neighbor_refine_scores(
    scores: dict[int, dict[int, float]],
    branch_class: dict[int, int],
    branch_parent: dict[int, int],
    branch_children: dict[int, list[int]],
) -> dict[int, dict[int, float]]:
    """Return a copy of per-branch class scores boosted toward neighbor labels.

    The parent branch's current class receives a fixed bonus
    (``parent_weight``); each labelled child's class receives an equal share
    of ``child_weight``. Weights come from the ``refinement`` config section.
    The input mappings are not mutated.
    """
    refine_cfg = _load_cfg().get("refinement", {})
    w_parent = float(refine_cfg.get("parent_weight", 0.14))
    w_children = float(refine_cfg.get("child_weight", 0.18))
    refined = {bid: dict(class_scores) for bid, class_scores in scores.items()}
    for bid, class_scores in refined.items():
        parent_bid = branch_parent.get(bid)
        if parent_bid in branch_class:
            # Nudge toward the parent's label.
            parent_cls = branch_class[parent_bid]
            class_scores[parent_cls] = class_scores.get(parent_cls, 0.0) + w_parent
        child_ids = branch_children.get(bid, [])
        if child_ids:
            # Children split the child bonus evenly; unlabelled ones are skipped.
            share = w_children / max(1, len(child_ids))
            for child in child_ids:
                child_cls = branch_class.get(child)
                if child_cls is None:
                    continue
                class_scores[child_cls] = class_scores.get(child_cls, 0.0) + share
    return refined
def _refine_topology_branch_labels(
    branch_class: dict[int, int],
    scores: dict[int, dict[int, float]],
    branch_parent: dict[int, int],
    branch_children: dict[int, list[int]],
    branch_lengths: dict[int, float],
) -> dict[int, int]:
    """Relabel short "island" branches sandwiched between same-class neighbors.

    A branch whose parent and at least one child agree on a different class
    is flipped to that class when it is short in absolute path length, short
    relative to those neighbors, or when its own score barely favors the
    current label. Runs ``iterations`` sweeps; each sweep reads the previous
    sweep's labels. The input mapping is not mutated.
    """
    if not branch_class:
        return branch_class
    refine_cfg = _load_cfg().get("refinement", {})
    n_sweeps = max(1, int(refine_cfg.get("iterations", 2)))
    max_abs_len = float(refine_cfg.get("island_max_path", 36.0))
    max_rel_len = float(refine_cfg.get("island_relative_max", 0.35))
    flip_margin = float(refine_cfg.get("island_flip_margin", 0.14))
    labels = dict(branch_class)
    for _sweep in range(n_sweeps):
        next_labels = dict(labels)
        for bid, cls in labels.items():
            parent_bid = branch_parent.get(bid)
            if parent_bid is None or parent_bid not in labels:
                continue
            target_cls = labels[parent_bid]
            if target_cls == cls:
                continue
            # The branch is only an "island" if some child shares the
            # parent's class, i.e. it interrupts an otherwise uniform run.
            agreeing = [c for c in branch_children.get(bid, []) if labels.get(c) == target_cls]
            if not agreeing:
                continue
            own_score = scores.get(bid, {}).get(cls, 0.0)
            target_score = scores.get(bid, {}).get(target_cls, 0.0)
            length = branch_lengths.get(bid, 0.0)
            # Longest neighboring branch (parent or agreeing child), floored
            # at 1.0 to keep the relative test meaningful for tiny lengths.
            neighbor_len = max(
                [branch_lengths.get(parent_bid, 0.0)]
                + [branch_lengths.get(c, 0.0) for c in agreeing]
                + [1.0]
            )
            if (
                length <= max_abs_len
                or length <= max_rel_len * neighbor_len
                or own_score - target_score < flip_margin
            ):
                next_labels[bid] = target_cls
        labels = next_labels
    return labels
def _apply_rules(rows: list[dict[str, Any]], opts: RuleBatchOptions) -> tuple[list[int], list[float], int, int]:
    """Apply the configured auto-typing rules to parsed SWC rows.

    Returns ``(types, radii, type_changes, radius_changes)``; ``rows`` is not
    modified. Stages, in order:
      1. optionally force root nodes (parent == -1) to soma (type 1);
      2. score and label branches for the enabled neurite classes
         (2=axon, 3=basal, 4=apical per the config's ``class_labels``)
         through several scoring / smoothing / refinement passes;
      3. re-assert soma on root nodes;
      4. optionally copy a positive parent radius onto non-positive radii.
    """
    orig_types = [int(row["type"]) for row in rows]
    types = list(orig_types)
    orig_radii = [float(row["radius"]) for row in rows]
    radii = list(orig_radii)
    parent_idx, children, order = _build_topology(rows)
    path_from_root, radial_from_root, _, _ = _compute_root_metrics(rows, parent_idx, children, order)
    if opts.soma:
        # Stage 1: roots become soma before branch labeling runs.
        for i, row in enumerate(rows):
            if int(row["parent"]) == -1 and types[i] != 1:
                types[i] = 1
    # Neurite classes the caller asked to (re)assign.
    enabled_neurites: set[int] = set()
    if opts.axon:
        enabled_neurites.add(2)
    if opts.basal:
        enabled_neurites.add(3)
    if opts.apic:
        enabled_neurites.add(4)
    if enabled_neurites:
        child_class, child_scores, node_child_owner = _assign_soma_child_subtrees(
            rows,
            parent_idx,
            children,
            types,
            enabled_neurites,
            path_from_root,
            radial_from_root,
        )
        # NOTE(review): child_class returned above is immediately replaced;
        # only child_scores / node_child_owner from that call are used.
        child_class = _enforce_primary_subtree_constraints(child_scores, enabled_neurites)
        branch_nodes, branch_anchor, node_branch = _branch_partition(rows, parent_idx, children, types)
        scores, features, existing_ratio = _branch_scores(
            rows,
            parent_idx,
            children,
            types,
            branch_nodes,
            branch_anchor,
            enabled_neurites,
            path_from_root,
            radial_from_root,
            node_child_owner,
            child_class,
            child_scores,
        )
        scores = _ml_refine_scores(scores, features, existing_ratio, enabled_neurites)
        # First labeling pass: assign, smooth, and pin owner-mandated labels.
        branch_class = _assign_branches(branch_nodes, scores, enabled_neurites)
        branch_class = _smooth_branch_labels(branch_class, scores, branch_anchor, node_branch)
        branch_class = _enforce_owner_labels_on_branches(branch_class, branch_nodes, node_child_owner, child_class)
        branch_parent, branch_children = _branch_graph(branch_anchor, node_branch)
        branch_lengths = _branch_path_lengths(rows, parent_idx, branch_nodes)
        # Second pass: re-score with neighbor context, then repeat the
        # assign/smooth/enforce sequence.
        scores = _neighbor_refine_scores(scores, branch_class, branch_parent, branch_children)
        branch_class = _assign_branches(branch_nodes, scores, enabled_neurites)
        branch_class = _smooth_branch_labels(branch_class, scores, branch_anchor, node_branch)
        branch_class = _enforce_owner_labels_on_branches(branch_class, branch_nodes, node_child_owner, child_class)
        # Final pass: flip short "island" branches, then enforce owners again.
        branch_class = _refine_topology_branch_labels(
            branch_class,
            scores,
            branch_parent,
            branch_children,
            branch_lengths,
        )
        branch_class = _enforce_owner_labels_on_branches(branch_class, branch_nodes, node_child_owner, child_class)
        # Write branch labels back to nodes; soma nodes are never overwritten
        # when soma enforcement is on.
        for bid, nodes in branch_nodes.items():
            cls = branch_class.get(bid)
            if cls is None:
                continue
            for i in nodes:
                if opts.soma and int(types[i]) == 1:
                    continue
                types[i] = cls
    if opts.soma:
        # Stage 3: re-assert soma on roots after branch labeling.
        for i, row in enumerate(rows):
            if int(row["parent"]) == -1:
                types[i] = 1
    if opts.rad:
        # Stage 4: radius repair — propagate a positive parent radius down to
        # non-positive radii, walking nodes in topological order.
        radius_cfg = _load_cfg().get("radius", {})
        copy_parent = bool(radius_cfg.get("copy_parent_if_zero", True))
        if copy_parent:
            for idx in order:
                pidx = parent_idx[idx]
                if pidx is None:
                    continue
                if radii[idx] <= 0 and radii[pidx] > 0:
                    radii[idx] = radii[pidx]
    type_changes = sum(1 for old, new in zip(orig_types, types) if int(old) != int(new))
    radius_changes = sum(1 for old, new in zip(orig_radii, radii) if float(old) != float(new))
    return types, radii, type_changes, radius_changes
def _write_swc(path: Path, headers: list[str], rows: list[dict[str, Any]], types: list[int], radii: list[float]) -> None:
with path.open("w", encoding="utf-8") as fh:
for h in headers:
fh.write(f"{h}\n")
for i, row in enumerate(rows):
fh.write(
f"{int(row['id'])} {int(types[i])} "
f"{float(row['x']):.10g} {float(row['y']):.10g} {float(row['z']):.10g} "
f"{float(radii[i]):.10g} {int(row['parent'])}\n"
)
def _build_change_details(
file_name: str,
rows: list[dict[str, Any]],
orig_types: list[int],
new_types: list[int],
orig_radii: list[float],
new_radii: list[float],
) -> list[str]:
out: list[str] = []
type_changes = sum(1 for old, new in zip(orig_types, new_types) if int(old) != int(new))
radius_changes = sum(1 for old, new in zip(orig_radii, new_radii) if float(old) != float(new))
if type_changes <= 0 and radius_changes <= 0:
return out
out.append(f"[{file_name}]")
if type_changes > 0:
out.append("type_changes:")
for row, old_t, new_t in zip(rows, orig_types, new_types):
if int(old_t) != int(new_t):
out.append(
f" node_id={int(row['id'])}: old_type={int(old_t)} -> new_type={int(new_t)}"
)
if radius_changes > 0:
out.append("radius_changes:")
for row, old_r, new_r in zip(rows, orig_radii, new_radii):
if float(old_r) != float(new_r):
out.append(
f" node_id={int(row['id'])}: "
f"old_radius={float(old_r):.10g} -> new_radius={float(new_r):.10g}"
)
out.append("")
return out
def run_rule_file(
    file_path: str,
    opts: RuleBatchOptions,
    *,
    output_path: str | None = None,
    write_output: bool = True,
    write_log: bool = True,
) -> RuleFileResult:
    """Run auto-typing on a single SWC file.

    Parses ``file_path``, applies the typing/radius rules, optionally writes
    the retyped SWC (to ``output_path`` or a timestamped default location)
    and a text report, and returns a RuleFileResult carrying the new types,
    radii, and change statistics.

    Raises:
        ValueError: if the file contains no valid SWC rows.
    """
    in_path = Path(file_path)
    headers, rows = _parse_swc(in_path)
    if not rows:
        raise ValueError(f"{in_path.name}: no valid SWC rows")
    # Snapshot originals so the change report can diff against them.
    orig_types = [int(r["type"]) for r in rows]
    orig_radii = [float(r["radius"]) for r in rows]
    types, radii, type_changes, radius_changes = _apply_rules(rows, opts)
    out_path: Path | None = None
    run_timestamp = timestamp_slug()
    if write_output:
        out_path = (
            Path(output_path)
            if output_path
            else operation_output_path_for_file(in_path, "auto_typing", timestamp=run_timestamp)
        )
        _write_swc(out_path, headers, rows, types, radii)
    # Per-class node counts (1=soma, 2=axon, 3=basal, 4=apical) for the report.
    out_counts = {
        1: sum(1 for t in types if int(t) == 1),
        2: sum(1 for t in types if int(t) == 2),
        3: sum(1 for t in types if int(t) == 3),
        4: sum(1 for t in types if int(t) == 4),
    }
    change_details = _build_change_details(
        in_path.name,
        rows,
        orig_types,
        types,
        orig_radii,
        radii,
    )
    log_path: str | None = None
    if write_log:
        # NOTE(review): with a caller-supplied output_path the log goes to the
        # legacy auto-typing log location; otherwise it uses the timestamped
        # operation report path — confirm this asymmetry is intended.
        log_target = (
            auto_typing_log_path_for_file(in_path)
            if output_path
            else operation_report_path_for_file(in_path, "auto_typing", timestamp=run_timestamp)
        )
        # Single-file payload mirrors the batch report shape so both feed
        # format_auto_typing_report_text.
        payload = {
            "folder": str(in_path.parent),
            "out_dir": str(out_path.parent if out_path is not None else in_path.parent),
            "zip_path": None,
            "files_total": 1,
            "files_processed": 1,
            "files_failed": 0,
            "total_nodes": len(rows),
            "total_type_changes": type_changes,
            "total_radius_changes": radius_changes,
            "failures": [],
            "per_file": [
                f"{in_path.name}: nodes={len(rows)}, type_changes={type_changes}, "
                f"radius_changes={radius_changes}, out_types(soma/axon/basal/apic)="
                f"{out_counts[1]}/{out_counts[2]}/{out_counts[3]}/{out_counts[4]}"
            ],
            "change_details": change_details,
        }
        log_path = write_text_report(log_target, format_auto_typing_report_text(payload))
    return RuleFileResult(
        input_file=str(in_path),
        output_file=str(out_path) if out_path is not None else None,
        nodes_total=len(rows),
        type_changes=type_changes,
        radius_changes=radius_changes,
        out_type_counts=out_counts,
        failures=[],
        change_details=change_details,
        log_path=log_path,
        headers=headers,
        rows=rows,
        types=types,
        radii=radii,
    )
def run_rule_batch(folder: str, opts: RuleBatchOptions) -> RuleBatchResult:
    """Run auto-typing over every ``.swc`` file directly inside ``folder``.

    Files are processed independently: a failing file is recorded in
    ``failures`` and the batch continues. Retyped files go to a timestamped
    operation output directory, optionally zipped; a text report is written
    alongside. Returns an aggregate RuleBatchResult.
    """
    in_dir = Path(folder)
    # Non-recursive: only top-level .swc files (case-insensitive suffix).
    swc_files = sorted([p for p in in_dir.iterdir() if p.is_file() and p.suffix.lower() == ".swc"])
    run_timestamp = timestamp_slug()
    out_dir = operation_output_dir_for_folder(in_dir, "batch_auto_typing", timestamp=run_timestamp)
    failures: list[str] = []
    per_file: list[str] = []
    change_details: list[str] = []
    processed = 0
    total_nodes = 0
    total_type_changes = 0
    total_radius_changes = 0
    for swc_path in swc_files:
        try:
            headers, rows = _parse_swc(swc_path)
            if not rows:
                failures.append(f"{swc_path.name}: no valid SWC rows")
                continue
            # Snapshot originals so the change report can diff against them.
            orig_types = [int(r["type"]) for r in rows]
            orig_radii = [float(r["radius"]) for r in rows]
            types, radii, type_changes, radius_changes = _apply_rules(rows, opts)
            out_path = operation_output_path_for_file(
                swc_path,
                "batch_auto_typing",
                output_dir=out_dir,
                timestamp=run_timestamp,
            )
            _write_swc(out_path, headers, rows, types, radii)
            processed += 1
            total_nodes += len(rows)
            total_type_changes += type_changes
            total_radius_changes += radius_changes
            # Per-class node counts (1=soma, 2=axon, 3=basal, 4=apical).
            out_counts = {
                1: sum(1 for t in types if int(t) == 1),
                2: sum(1 for t in types if int(t) == 2),
                3: sum(1 for t in types if int(t) == 3),
                4: sum(1 for t in types if int(t) == 4),
            }
            per_file.append(
                f"{swc_path.name}: nodes={len(rows)}, type_changes={type_changes}, "
                f"radius_changes={radius_changes}, out_types(soma/axon/basal/apic)="
                f"{out_counts[1]}/{out_counts[2]}/{out_counts[3]}/{out_counts[4]}"
            )
            change_details.extend(
                _build_change_details(
                    swc_path.name,
                    rows,
                    orig_types,
                    types,
                    orig_radii,
                    radii,
                )
            )
        except Exception as e:
            # Best-effort batch: record the failure and keep going.
            failures.append(f"{swc_path.name}: {e}")
    zip_path: str | None = None
    if opts.zip_output and processed > 0:
        # Archive the output directory's SWC files under a folder-named prefix.
        zip_target = in_dir / f"{out_dir.name}.zip"
        with zipfile.ZipFile(zip_target, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for f in sorted(out_dir.glob("*.swc")):
                zf.write(f, arcname=f"{out_dir.name}/{f.name}")
        zip_path = str(zip_target)
    payload = {
        "folder": str(in_dir),
        "out_dir": str(out_dir),
        "zip_path": zip_path,
        "files_total": len(swc_files),
        "files_processed": processed,
        "files_failed": len(failures),
        "total_nodes": total_nodes,
        "total_type_changes": total_type_changes,
        "total_radius_changes": total_radius_changes,
        "failures": failures,
        "per_file": per_file,
        "change_details": change_details,
    }
    log_path = write_text_report(
        operation_report_path_for_folder(in_dir, "batch_auto_typing", output_dir=out_dir, timestamp=run_timestamp),
        format_auto_typing_report_text(payload),
    )
    return RuleBatchResult(
        folder=str(in_dir),
        out_dir=str(out_dir),
        zip_path=zip_path,
        files_total=len(swc_files),
        files_processed=processed,
        files_failed=len(failures),
        total_nodes=total_nodes,
        total_type_changes=total_type_changes,
        total_radius_changes=total_radius_changes,
        failures=failures,
        per_file=per_file,
        log_path=log_path,
    )