"""
Phase 2 v2: Parse Compose UI Trees using tree-sitter AST
Replaces regex+brace-depth parsing with proper AST analysis via tree-sitter-kotlin.
Handles all Kotlin patterns: lambda defaults, CompositionLocalProvider, nested control
flow, Modifier chains, and recursive custom composable expansion.
Usage:
python3 parse_compose_tree_v2.py <project_root> <discovery_json> <output_dir>
Output:
<output_dir>/compose_trees/<ScreenName>.json
"""
import json
import os
import re
import sys
from pathlib import Path
import tree_sitter
import tree_sitter_kotlin as tsk
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from salt_ui_mapping import get_component_info, is_non_ui_call, COMPOSE_NON_UI
KOTLIN = tree_sitter.Language(tsk.language())
import xml.etree.ElementTree as ET
class StringResourceResolver:
def __init__(self):
self.strings: dict[str, str] = {}
def load(self, project_root: str):
res_base = os.path.join(project_root, "app", "src", "main", "res")
if not os.path.isdir(res_base):
return
for d in sorted(os.listdir(res_base)):
if d.startswith("values"):
vdir = os.path.join(res_base, d)
if os.path.isdir(vdir):
for f in os.listdir(vdir):
if f.endswith(".xml"):
self._parse(os.path.join(vdir, f), d == "values")
def _parse(self, path, is_default):
try:
root = ET.parse(path).getroot()
except ET.ParseError:
return
for e in root:
if e.tag == "string":
name = e.get("name", "")
text = (e.text or "").strip()
if name and (is_default or name not in self.strings):
self.strings[name] = text
def resolve(self, text: str) -> str | None:
m = re.search(r'R\.string\.(\w+)', text)
if m and m.group(1) in self.strings:
return self.strings[m.group(1)]
return None
def node_text(node) -> str:
    """Decode a node's source bytes as UTF-8 (invalid bytes -> U+FFFD); "" for a falsy node."""
    if not node:
        return ""
    raw = node.text
    return raw.decode("utf-8", errors="replace")
def find_children(node, type_name: str):
    """Return all direct children of *node* whose grammar type equals *type_name*, in order."""
    return list(filter(lambda kid: kid.type == type_name, node.children))
def find_child(node, type_name: str):
    """Return the first direct child of *node* with grammar type *type_name*, or None."""
    return next((kid for kid in node.children if kid.type == type_name), None)
def is_composable_name(name: str) -> bool:
    """Heuristically decide whether a call named *name* is a composable.

    A name qualifies when it is non-empty, starts with an uppercase letter and
    is not reported as a non-UI call by ``is_non_ui_call`` (which is only
    consulted for capitalised names).
    """
    if not name:
        return False
    if not name[0].isupper():
        return False
    return not is_non_ui_call(name)
def get_call_name(call_node) -> str:
    """Return the callee name of a call_expression, keeping any dotted receiver.

    The first child that can name the callee decides the result:

    * ``identifier`` -> its text;
    * ``navigation_expression`` -> its identifier parts joined with ``.``
      (or the whole expression text when it has no identifier children);
    * nested ``call_expression`` -> resolved recursively.

    Returns ``""`` when no child matches.
    """
    for kid in call_node.children:
        kind = kid.type
        if kind == "identifier":
            return node_text(kid)
        if kind == "navigation_expression":
            segments = [node_text(part) for part in kid.children if part.type == "identifier"]
            if not segments:
                return node_text(kid)
            return ".".join(segments)
        if kind == "call_expression":
            return get_call_name(kid)
    return ""
def get_simple_call_name(call_node) -> str:
    """Return just the callee's own name for a call_expression, without its receiver.

    Lookup order: a direct ``identifier`` child; otherwise recurse into a
    nested ``call_expression`` (the shape used for calls with a trailing
    lambda); otherwise the last identifier of a ``navigation_expression``
    (``a.b.c(...)`` -> ``"c"``).  Returns ``""`` when none applies.
    """
    direct = find_child(call_node, "identifier")
    if direct is not None:
        return node_text(direct)
    nested_call = find_child(call_node, "call_expression")
    if nested_call:
        return get_simple_call_name(nested_call)
    receiver_chain = find_child(call_node, "navigation_expression")
    if not receiver_chain:
        return ""
    names = [node_text(part) for part in receiver_chain.children if part.type == "identifier"]
    return names[-1] if names else ""
def parse_modifier_chain(node) -> list[dict]:
    """Flatten a Modifier chain such as ``Modifier.fillMaxSize().padding(16.dp)``.

    Returns one dict per chained call (``{"name": ..., "args": {...}}`` plus
    optional lambda info), ordered innermost/left-most call first.  The list
    is empty when *node* is None or contains no modifier calls.
    """
    chain: list[dict] = []
    _collect_modifier_calls(node, result=chain)
    return chain
def _record_modifier_lambda(entry: dict, lambda_node) -> None:
    """Flag a modifier *entry* as taking a trailing lambda and record navigation inside it."""
    entry["has_lambda"] = True
    nav_targets = _find_navigate_calls(lambda_node)
    if nav_targets:
        entry["navigates_to"] = nav_targets


def _collect_modifier_calls(node, result: list):
    """Recursively append the modifier calls of a chained expression to *result*.

    Receivers are visited before the call itself, so ``Modifier.a().b()``
    yields ``a`` then ``b``.  Each entry is ``{"name": <method>, "args": {...}}``;
    calls carrying a trailing lambda also get ``has_lambda`` and, when the
    lambda contains navigate/safeNavigate calls, ``navigates_to``.  A call
    named ``Modifier`` itself is never recorded.

    Three ``call_expression`` shapes are recognised:

    * ``<receiver>.name(args)`` / ``<receiver>.name { ... }`` - the callee is a
      ``navigation_expression`` child;
    * ``name(args)`` - a bare call with a ``value_arguments`` child;
    * ``<receiver>.name(args) { ... }`` - the call-with-arguments is wrapped
      in an outer ``call_expression`` whose other child is the trailing
      ``annotated_lambda`` (the same shape ``_parse_call_expression``
      documents for composables).  This covers e.g. ``pointerInput(Unit) { }``
      and ``clickable(enabled = x) { }``.
    """
    if node is None:
        return
    if node.type == "navigation_expression":
        inner_call = find_child(node, "call_expression")
        if inner_call:
            _collect_modifier_calls(inner_call, result)
        return
    if node.type != "call_expression":
        return

    nav = find_child(node, "navigation_expression")
    args_node = find_child(node, "value_arguments")
    lambda_node = find_child(node, "annotated_lambda")

    if nav:
        inner_call = find_child(nav, "call_expression")
        if inner_call:
            _collect_modifier_calls(inner_call, result)
        ids = [node_text(c) for c in nav.children if c.type == "identifier"]
        method_name = ids[-1] if ids else ""
        if method_name and method_name != "Modifier":
            entry = {
                "name": method_name,
                "args": _extract_value_args(args_node) if args_node else {},
            }
            if lambda_node:
                _record_modifier_lambda(entry, lambda_node)
            result.append(entry)
    elif args_node:
        name = get_simple_call_name(node)
        if name and name != "Modifier":
            result.append({"name": name, "args": _extract_value_args(args_node)})
    elif lambda_node:
        # Wrapped shape: collect the inner call, then attach the lambda to
        # the entry it produced (the last one appended, if any).
        wrapped = find_child(node, "call_expression")
        if wrapped:
            before = len(result)
            _collect_modifier_calls(wrapped, result)
            if len(result) > before:
                _record_modifier_lambda(result[-1], lambda_node)
def _find_navigate_calls(node) -> list[str]:
    """List the route names passed to ``navigate(...)`` / ``safeNavigate(...)`` under *node*.

    This is a regex scan over the node's source text, not an AST walk.  An
    optional ``ScreenRoute.`` prefix is dropped, and the first word after the
    opening parenthesis is captured.  Results are in source order.
    """
    pattern = r'(?:navigate|safeNavigate)\s*\(\s*(?:ScreenRoute\.)?(\w+)'
    source = node_text(node)
    return [hit.group(1) for hit in re.finditer(pattern, source)]
def _extract_value_args(args_node) -> dict:
    """Turn a ``value_arguments`` node into a dict of argument texts.

    Named arguments (``name = value``) are keyed by their name.  Every other
    non-empty argument is stored as ``_pos_0``, ``_pos_1``, ... in order of
    appearance.  Values are source text, not evaluated.  A None node gives {}.
    """
    if args_node is None:
        return {}
    extracted = {}
    positional = 0
    for arg in args_node.children:
        if arg.type != "value_argument":
            continue
        label = find_child(arg, "identifier")
        has_equals = any(part.type == "=" for part in arg.children)
        if not (label and has_equals):
            text = node_text(arg).strip()
            if text:
                extracted[f"_pos_{positional}"] = text
                positional += 1
            continue
        # Named argument: the value is everything after the "=" token.
        seen_equals = False
        pieces = []
        for part in arg.children:
            if part.type == "=":
                seen_equals = True
            elif seen_equals and (part.type != "identifier" or part != label):
                pieces.append(node_text(part))
        value = " ".join(pieces).strip()
        if not value:
            # Fallback: regex over the argument's raw text.
            whole = node_text(arg)
            match = re.match(r'\w+\s*=\s*(.+)', whole, re.DOTALL)
            value = match.group(1).strip() if match else whole
        extracted[node_text(label)] = value
    return extracted
class ComposeTreeBuilder:
    """Turn the bodies of ``@Composable`` functions in one Kotlin file into UI trees.

    The source is parsed once with tree-sitter-kotlin.  The methods below walk
    the AST and emit JSON-ready dicts:

    * composable calls: ``{"type": <name>, ...}`` plus, when present,
      ``params``, ``resolved_strings``, ``modifiers``, ``actions``,
      ``content_params`` and ``children``.  They also carry either the Salt
      mapping fields (``salt_semantic``/``library``/``arkts_hint``/``clickable``)
      or ``custom: True`` when ``get_component_info`` has no mapping;
    * ``LazyScope.<name>`` for lazy-list DSL calls (``items { ... }`` etc.);
    * ``ConditionalBlock`` for if/else, ``WhenBlock`` for when and
      ``ForLoop`` for for-loops.

    Nesting deeper than ``MAX_DEPTH`` levels is cut off.
    """

    MAX_DEPTH = 20

    # Named arguments treated as event handlers: reported under "actions"
    # and never parsed as composable content slots.
    EVENT_PARAMS = frozenset({
        "onClick", "onLongClick", "onChange", "onValueChange",
        "onConfirm", "onDismissRequest", "onYes", "onNo", "onRefresh",
    })
    # Named arguments that are never parsed as content slots.
    NON_CONTENT_PARAMS = EVENT_PARAMS | frozenset({"modifier"})
    # Wrappers with no UI of their own: their trailing-lambda content is
    # emitted in place of the call.
    TRANSPARENT_CONTAINERS = frozenset({"CompositionLocalProvider", "ProvideTextStyle", "MaterialTheme"})
    # Kotlin scope functions and Compose's key(): their trailing lambda runs
    # inline, so composables called inside belong to the enclosing layout.
    TRANSPARENT_SCOPE_CALLS = frozenset({"key", "let", "run", "with", "apply", "also"})
    # LazyColumn/LazyRow DSL calls, emitted as "LazyScope.<name>" nodes.
    LAZY_SCOPE_CALLS = frozenset({"items", "itemsIndexed", "item", "stickyHeader"})
    # Node types that hold a when-entry condition (used when no "->" token is visible).
    _WHEN_CONDITION_TYPES = ("type_test", "range_test", "expression", "when_condition")

    def __init__(self, source: bytes, resolver: StringResourceResolver | None = None):
        self.source = source
        self.resolver = resolver
        self.parser = tree_sitter.Parser(KOTLIN)
        self.tree = self.parser.parse(source)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_list(extracted) -> list[dict]:
        """Normalise an extraction result (dict, list or None) to a list of nodes."""
        if not extracted:
            return []
        return extracted if isinstance(extracted, list) else [extracted]

    # -------------------------------------------------------- function discovery

    def find_composable_functions(self) -> list[tuple[str, object]]:
        """Return ``(name, function_declaration node)`` for every @Composable fun, in source order."""
        results = []
        self._find_composable_fns(self.tree.root_node, results)
        return results

    def _find_composable_fns(self, node, results):
        """Append every @Composable function under *node* to *results* (pre-order).

        Uses an explicit stack instead of recursion so deeply nested sources
        cannot hit Python's recursion limit.
        """
        stack = [node]
        while stack:
            current = stack.pop()
            if current.type == "function_declaration":
                mods = find_child(current, "modifiers")
                if mods and b"@Composable" in mods.text:
                    name = find_child(current, "identifier")
                    if name:
                        results.append((node_text(name), current))
            # reversed() so the first child is popped (visited) first.
            stack.extend(reversed(current.children))

    def extract_function_params(self, func_node) -> list[dict]:
        """Return ``[{"name", "type", "default"}]`` for the function's parameters.

        ``type`` is the declared type's source text ("" if none), and
        ``default`` is the default-value source text or None.
        """
        params_node = find_child(func_node, "function_value_parameters")
        if not params_node:
            return []
        siblings = params_node.children
        params = []
        for idx, child in enumerate(siblings):
            if child.type != "parameter":
                continue
            name = find_child(child, "identifier")
            params.append({
                "name": node_text(name) if name else "",
                "type": self._param_type_text(child),
                "default": self._param_default_text(child, siblings, idx),
            })
        return params

    @staticmethod
    def _param_type_text(param) -> str:
        """Source text after the parameter's ':' (covers nullable/function/generic types)."""
        colon = eq = None
        for c in param.children:
            if c.type == ":" and colon is None:
                colon = c
            elif c.type == "=" and colon is not None:
                eq = c
                break
        if colon is None:
            return ""
        start = colon.end_byte - param.start_byte
        end = (eq.start_byte if eq is not None else param.end_byte) - param.start_byte
        return param.text[start:end].decode("utf-8", errors="replace").strip()

    @staticmethod
    def _param_default_text(param, siblings, idx) -> str | None:
        """Default value of the parameter at ``siblings[idx]``, or None.

        The default is the expression right after the "=" that directly
        follows this parameter: either inside the parameter node or as its
        next sibling, with comments skipped.  Only this parameter's own "=" is
        considered; an "=" elsewhere on the same line is ignored.
        """
        inner = param.children
        for i, c in enumerate(inner):
            if c.type == "=" and i + 1 < len(inner):
                return node_text(inner[i + 1]).strip()

        def skip_comments(j):
            while j < len(siblings) and "comment" in siblings[j].type:
                j += 1
            return j

        j = skip_comments(idx + 1)
        if j < len(siblings) and siblings[j].type == "=":
            j = skip_comments(j + 1)
            if j < len(siblings):
                return node_text(siblings[j]).strip()
        return None

    # ------------------------------------------------------------ body parsing

    def parse_function_body(self, func_node) -> list[dict]:
        """Parse a function's body into a list of UI tree nodes.

        Handles both block bodies ``{ ... }`` and expression bodies
        ``= Column { ... }``.
        """
        body = find_child(func_node, "function_body")
        if not body:
            return []
        block = find_child(body, "block")
        if block:
            return self._parse_block(block)
        nodes: list[dict] = []
        for child in body.children:
            if child.type != "=":
                nodes.extend(self._as_list(self._extract_ui_node(child)))
        return nodes

    def _parse_block(self, block_node, depth=0) -> list[dict]:
        """Parse a block ``{ ... }`` into a flat list of composable tree nodes."""
        if depth > self.MAX_DEPTH:
            return []
        nodes: list[dict] = []
        for child in block_node.children:
            if child.type in ("{", "}"):
                continue
            nodes.extend(self._as_list(self._extract_ui_node(child, depth)))
        return nodes

    def _extract_ui_node(self, node, depth=0) -> dict | list[dict] | None:
        """Dispatch an AST statement to the matching parser; None if it holds no UI."""
        if depth > self.MAX_DEPTH:
            return None
        if node.type == "call_expression":
            return self._parse_call_expression(node, depth)
        if node.type == "if_expression":
            return self._parse_if_expression(node, depth)
        if node.type == "when_expression":
            return self._parse_when_expression(node, depth)
        if node.type == "for_statement":
            return self._parse_for_statement(node, depth)
        # property_declaration and everything else carries no UI.
        return None

    def _parse_call_expression(self, node, depth) -> dict | list[dict] | None:
        """Parse a composable call_expression into a tree node.

        tree-sitter represents `Foo(params) { content }` as:
            call_expression ← OUTER
                call_expression ← INNER: has identifier + value_arguments
                    identifier = 'Foo'
                    value_arguments(...)
                annotated_lambda ← trailing lambda on OUTER
        Simple calls without trailing lambda:
            call_expression
                identifier = 'Foo'
                value_arguments(...)

        Returns a node dict, a list of nodes (for transparent wrappers whose
        content is inlined), or None for non-UI calls.
        """
        inner_call = find_child(node, "call_expression")
        trailing_lambda = find_child(node, "annotated_lambda")
        if inner_call and trailing_lambda:
            name = get_simple_call_name(inner_call)
            args_node = find_child(inner_call, "value_arguments")
        else:
            name = get_simple_call_name(node)
            args_node = find_child(node, "value_arguments")
        if not name:
            return None

        if name in self.TRANSPARENT_CONTAINERS or name in self.TRANSPARENT_SCOPE_CALLS:
            return self._trailing_lambda_children(trailing_lambda, depth) or None

        if name in self.LAZY_SCOPE_CALLS:
            children = self._trailing_lambda_children(trailing_lambda, depth)
            if not children:
                return None
            return {
                "type": f"LazyScope.{name}",
                "params": _extract_value_args(args_node) if args_node else {},
                "children": children,
            }

        # Rejects lowercase names and known non-UI calls.
        if not is_composable_name(name):
            return None

        result = {"type": name}
        comp_info = get_component_info(name)
        if comp_info:
            result["salt_semantic"] = comp_info.get("type")
            result["library"] = comp_info.get("library")
            result["arkts_hint"] = comp_info.get("arkts_hint")
            if comp_info.get("clickable"):
                result["clickable"] = True
        else:
            result["custom"] = True

        if args_node:
            self._attach_params(result, args_node)
            actions = self._extract_actions_from(args_node)
            if actions:
                result["actions"] = actions
            content_params = self._extract_content_params(args_node, depth)
            if content_params:
                result["content_params"] = content_params

        children = self._trailing_lambda_children(trailing_lambda, depth)
        if children:
            result["children"] = children
        return result

    def _attach_params(self, result: dict, args_node) -> None:
        """Store params, resolved R.string texts and the parsed modifier chain on *result*."""
        params = _extract_value_args(args_node)
        if not params:
            return
        result["params"] = params

        if self.resolver:
            resolved = {}
            for k, v in params.items():
                if isinstance(v, str) and "R.string." in v:
                    text = self.resolver.resolve(v)
                    if text:
                        resolved[k] = text
            if resolved:
                result["resolved_strings"] = resolved
                params.update(resolved)

        if "modifier" not in params:
            return
        # Covers both `Modifier.x()` and pass-through `modifier.x()` chains.
        for va in find_children(args_node, "value_argument"):
            va_name = find_child(va, "identifier")
            if not va_name or node_text(va_name) != "modifier":
                continue
            chain = next(
                (c for c in va.children if c.type in ("call_expression", "navigation_expression")),
                None,
            )
            if chain is not None:
                mods = parse_modifier_chain(chain)
                if mods:
                    result["modifiers"] = mods

    def _extract_content_params(self, args_node, depth) -> dict:
        """Map named lambda arguments (``topBar = { ... }``) to their parsed UI children."""
        slots = {}
        for va in find_children(args_node, "value_argument"):
            va_name = find_child(va, "identifier")
            if not va_name:
                continue
            pname = node_text(va_name)
            if pname in self.NON_CONTENT_PARAMS:
                continue
            lit = find_child(va, "lambda_literal")
            if lit:
                children = self._parse_lambda_body(lit, depth + 1)
                if children:
                    slots[pname] = children
        return slots

    def _trailing_lambda_children(self, trailing_lambda, depth) -> list[dict]:
        """Parse the UI inside a trailing ``annotated_lambda``; [] when absent."""
        if not trailing_lambda:
            return []
        lit = find_child(trailing_lambda, "lambda_literal")
        return self._parse_lambda_body(lit, depth + 1) if lit else []

    def _extract_actions_from(self, args_node) -> list[dict]:
        """Extract event-handler actions (``EVENT_PARAMS``) from value_arguments."""
        if not args_node:
            return []
        actions = []
        for va in find_children(args_node, "value_argument"):
            name_id = find_child(va, "identifier")
            if not name_id:
                continue
            pname = node_text(name_id)
            if pname not in self.EVENT_PARAMS:
                continue
            action = {"event": pname}
            handler_text = node_text(va)
            nav_targets = _find_navigate_calls(va)
            if nav_targets:
                action["navigates_to"] = nav_targets[0]
            activity_match = re.search(r'(\w+Activity)::class', handler_text)
            if activity_match:
                action["starts_activity"] = activity_match.group(1)
            m = re.search(r'=\s*\{(.*)\}', handler_text, re.DOTALL)
            if m:
                body = m.group(1).strip()
                # Only short handlers are inlined, to keep the JSON readable.
                if len(body) < 150:
                    action["handler_body"] = body
            actions.append(action)
        return actions

    def _parse_lambda_body(self, lambda_lit, depth) -> list[dict]:
        """Parse a lambda_literal's statements for composable children."""
        nodes: list[dict] = []
        for child in lambda_lit.children:
            if child.type in ("{", "}", "lambda_parameters", "->"):
                continue
            nodes.extend(self._as_list(self._extract_ui_node(child, depth)))
        return nodes

    # ------------------------------------------------------------ control flow

    def _parse_branch(self, node, depth) -> list[dict]:
        """Parse one if/when branch body (a block or a single statement) into nodes."""
        if node.type == "block":
            return self._parse_block(node, depth)
        return self._as_list(self._extract_ui_node(node, depth))

    @staticmethod
    def _split_if(node) -> tuple[str, list]:
        """Split an if_expression into ``(condition text, nodes after the condition)``.

        ``if (cond) then else other`` puts "if", "(", cond and ")" as direct
        children.  The condition is everything between the first "(" and the
        following ")", taken verbatim, so nested parentheses are preserved.
        """
        children = node.children
        open_idx = next((i for i, c in enumerate(children) if c.type == "("), None)
        if open_idx is not None:
            close_idx = next(
                (i for i in range(open_idx + 1, len(children)) if children[i].type == ")"),
                None,
            )
            if close_idx is not None and close_idx > open_idx + 1:
                first, last = children[open_idx + 1], children[close_idx - 1]
                text = node.text[first.start_byte - node.start_byte:last.end_byte - node.start_byte]
                return text.decode("utf-8", errors="replace").strip(), children[close_idx + 1:]
        # Fallback for unexpected shapes: first non-body child is the condition.
        condition = ""
        for child in children:
            if child.type in ("(", ")", "if", "else", "{", "block",
                              "call_expression", "if_expression", "when_expression"):
                continue
            ctext = node_text(child).strip("()")
            if ctext:
                condition = ctext
                break
        return condition, [c for c in children if c.type not in ("if", "(", ")")]

    def _parse_if_expression(self, node, depth) -> dict | None:
        """Parse if/else into a ConditionalBlock node (None when neither branch has UI)."""
        condition, body_nodes = self._split_if(node)
        then_children: list[dict] = []
        else_children: list[dict] = []
        in_else = False
        for child in body_nodes:
            if child.type == "else":
                in_else = True
                continue
            parsed = self._parse_branch(child, depth + 1)
            if not parsed:
                continue
            if in_else:
                else_children = parsed
            else:
                then_children = parsed
        if not then_children and not else_children:
            return None
        result = {
            "type": "ConditionalBlock",
            "condition": condition,
            "then_children": then_children,
        }
        if else_children:
            result["else_children"] = else_children
        return result

    def _parse_when_expression(self, node, depth) -> dict | None:
        """Parse a when expression into a WhenBlock node (None when no branch is usable).

        Each branch is ``{"condition", "children"}``.  Comma-separated
        conditions are joined with ", ", and the ``else`` branch additionally
        carries ``"is_else": True``.
        """
        subject = ""
        subject_node = find_child(node, "when_subject")
        if subject_node:
            subject = node_text(subject_node).strip()
            # Drop exactly one enclosing pair of parentheses.
            if subject.startswith("(") and subject.endswith(")"):
                subject = subject[1:-1].strip()

        branches = []
        for entry in find_children(node, "when_entry"):
            conditions: list[str] = []
            children: list[dict] = []
            is_else = False
            has_arrow = any(c.type == "->" for c in entry.children)
            seen_arrow = False
            for child in entry.children:
                kind = child.type
                if kind == "->":
                    seen_arrow = True
                    continue
                if kind == ",":
                    continue
                if kind == "else":
                    is_else = True
                    continue
                # Before "->" everything is a condition.  Without a visible
                # arrow, fall back to classifying by node type.
                is_condition = (not seen_arrow and "comment" not in kind) if has_arrow \
                    else kind in self._WHEN_CONDITION_TYPES
                if is_condition:
                    conditions.append(node_text(child))
                else:
                    parsed = self._parse_branch(child, depth + 1)
                    if parsed:
                        children = parsed
            condition = ", ".join(conditions)
            if condition or children:
                branch = {"condition": condition, "children": children}
                if is_else:
                    branch["is_else"] = True
                branches.append(branch)
        if not branches:
            return None
        return {
            "type": "WhenBlock",
            "expression": subject,
            "branches": branches,
        }

    def _parse_for_statement(self, node, depth) -> dict | None:
        """Parse a for loop into a ForLoop node (None when the body has no UI).

        ``expression`` is the loop header text, cut exactly at the start of
        the body.
        """
        if not node.children:
            return None
        body = find_child(node, "block")
        if body:
            children = self._parse_block(body, depth + 1)
        else:
            # Brace-less body: `for (x in xs) Item(x)`.
            body = node.children[-1]
            children = self._as_list(self._extract_ui_node(body, depth + 1))
        if not children:
            return None
        header = node.text[:body.start_byte - node.start_byte]
        return {
            "type": "ForLoop",
            "expression": header.decode("utf-8", errors="replace").strip(),
            "children": children,
        }
class RecursiveExpander:
    """Recursively expand custom composables by parsing their source definitions.

    A tree node flagged ``custom`` (no Salt mapping) that has no children of
    its own receives ``expanded_children``, parsed from the ``@Composable fun``
    with the same name found under *project_root*.  It also gets
    ``expanded_from`` and loses the ``custom`` flag.

    Files are read and parsed lazily and cached.  Lookups, including misses,
    are memoised, so each name is searched for at most once.
    """

    # Suffixes stripped from a composable's name to guess the file that
    # defines it (e.g. ``LoginContent`` -> ``Login.kt``).
    _NAME_SUFFIXES = ("Content", "Body", "View", "UI", "Impl", "Internal")
    _CHILD_KEYS = ("children", "expanded_children", "then_children", "else_children")

    def __init__(self, project_root: str, resolver: StringResourceResolver | None):
        self.project_root = project_root
        self.resolver = resolver
        self._source_cache: dict[str, bytes] = {}
        # path -> builder, or None when the file could not be read
        self._tree_cache: dict[str, ComposeTreeBuilder | None] = {}
        # file stem -> every .kt path with that stem (stems repeat across modules)
        self._kt_index: dict[str, list[str]] | None = None
        self._kt_files: list[str] = []
        # path -> {composable name -> function node}
        self._fn_cache: dict[str, dict[str, object]] = {}
        # composable name -> (builder, function node) or None when not found
        self._def_cache: dict[str, tuple[ComposeTreeBuilder, object] | None] = {}
        # names currently being expanded (guards against recursive composables)
        self._expanding: set[str] = set()

    def _build_index(self):
        """Index every .kt file under the project root once (hidden dirs skipped)."""
        if self._kt_index is not None:
            return
        self._kt_index = {}
        for root, dirs, files in os.walk(self.project_root):
            # Prune .git/.gradle/.idea etc. and walk in a stable order.
            dirs[:] = sorted(d for d in dirs if not d.startswith("."))
            for f in sorted(files):
                if f.endswith(".kt"):
                    path = os.path.join(root, f)
                    self._kt_index.setdefault(f[:-3], []).append(path)
                    self._kt_files.append(path)

    def _get_builder(self, path: str) -> ComposeTreeBuilder | None:
        """Return the cached builder for *path*, or None if the file cannot be read."""
        if path not in self._tree_cache:
            if path not in self._source_cache:
                try:
                    with open(path, "rb") as f:
                        self._source_cache[path] = f.read()
                except OSError:
                    self._tree_cache[path] = None
                    return None
            self._tree_cache[path] = ComposeTreeBuilder(self._source_cache[path], self.resolver)
        return self._tree_cache[path]

    def _composables_in(self, path: str) -> dict[str, object]:
        """Map composable name -> function node for *path* (first definition wins)."""
        if path not in self._fn_cache:
            fns: dict[str, object] = {}
            builder = self._get_builder(path)
            if builder is not None:
                for fn_name, fn_node in builder.find_composable_functions():
                    fns.setdefault(fn_name, fn_node)
            self._fn_cache[path] = fns
        return self._fn_cache[path]

    def _find_composable_def(self, name: str) -> tuple[ComposeTreeBuilder, object] | None:
        """Find a @Composable fun definition by name.

        Files named after the composable (or after the composable minus a
        common suffix) are searched first, then every other .kt file.
        """
        if name in self._def_cache:
            return self._def_cache[name]
        self._build_index()
        candidates = list(self._kt_index.get(name, []))
        for suffix in self._NAME_SUFFIXES:
            if name.endswith(suffix) and len(name) > len(suffix):
                candidates.extend(self._kt_index.get(name[:-len(suffix)], []))
        found = None
        # dict.fromkeys de-duplicates while keeping priority order.
        for path in dict.fromkeys(candidates + self._kt_files):
            fn_node = self._composables_in(path).get(name)
            if fn_node is not None:
                found = (self._get_builder(path), fn_node)
                break
        self._def_cache[name] = found
        return found

    def _expand_node(self, node: dict, max_depth: int, current_depth: int) -> bool:
        """Expand one custom node in place; return True if it was expanded.

        Its new children are expanded recursively while the name is still in
        ``_expanding``, so a composable never expands inside itself.
        """
        comp_name = node.get("type", "")
        if not comp_name or comp_name in self._expanding:
            return False
        self._expanding.add(comp_name)
        try:
            found = self._find_composable_def(comp_name)
            if found is None:
                return False
            builder, fn_node = found
            children = builder.parse_function_body(fn_node)
            if not children:
                return False
            node["expanded_children"] = children
            node.pop("custom", None)
            node["expanded_from"] = comp_name
            self.expand_tree(children, max_depth, current_depth + 1)
            return True
        finally:
            self._expanding.discard(comp_name)

    def expand_tree(self, nodes: list[dict], max_depth: int = 3, current_depth: int = 0):
        """Recursively expand custom composable nodes of *nodes* in place.

        Recursion stops at *max_depth* levels.  WhenBlock branches stay at the
        same depth as their WhenBlock; all other child lists, including named
        content slots (``content_params``), go one level deeper.
        """
        if current_depth >= max_depth:
            return
        for node in nodes:
            if not isinstance(node, dict):
                continue
            just_expanded = False
            if node.get("custom") and "children" not in node and "expanded_children" not in node:
                just_expanded = self._expand_node(node, max_depth, current_depth)
            for key in self._CHILD_KEYS:
                if key == "expanded_children" and just_expanded:
                    continue  # already walked by _expand_node under the recursion guard
                if key in node:
                    self.expand_tree(node[key], max_depth, current_depth + 1)
            for slot_children in node.get("content_params", {}).values():
                self.expand_tree(slot_children, max_depth, current_depth + 1)
            for branch in node.get("branches", []):
                self.expand_tree(branch.get("children", []), max_depth, current_depth)
def process_screen(project_root: str, screen_name: str, source_file: str,
                   resolver: StringResourceResolver | None = None) -> dict:
    """Parse the composable *screen_name* from *source_file* into a tree JSON dict.

    Args:
        project_root: Directory that *source_file* is relative to.
        screen_name: Name of the @Composable function to extract.
        source_file: Kotlin file path, relative to *project_root*.
        resolver: Optional string-resource resolver for ``R.string`` params.

    Returns:
        ``{"error": ...}`` if the file is missing, not a regular file or
        unreadable.  Otherwise a dict with ``composable_name``, ``source_file``
        and ``compose_tree``:

        * when *screen_name* is found, the dict also has
          ``function_parameters``;
        * when it is not found, every composable with a non-empty body is
          returned under ``all_composables``, and ``compose_tree`` is the
          single tree if exactly one exists, else ``[]``.
    """
    full_path = os.path.join(project_root, source_file)
    if not os.path.isfile(full_path):
        return {"error": f"Source file not found: {source_file}"}
    try:
        with open(full_path, "rb") as f:
            source = f.read()
    except OSError as exc:
        return {"error": f"Cannot read source file {source_file}: {exc}"}

    builder = ComposeTreeBuilder(source, resolver)
    composables = builder.find_composable_functions()
    target_node = next(
        (fn_node for fn_name, fn_node in composables if fn_name == screen_name),
        None,
    )

    if target_node is None:
        all_trees = {}
        for fn_name, fn_node in composables:
            tree = builder.parse_function_body(fn_node)
            if tree:
                all_trees[fn_name] = tree
        if all_trees:
            return {
                "composable_name": screen_name,
                "source_file": source_file,
                "note": f"Target function '{screen_name}' not found, extracted {len(all_trees)} composables",
                "all_composables": dict(all_trees),
                "compose_tree": next(iter(all_trees.values())) if len(all_trees) == 1 else [],
            }
        return {
            "composable_name": screen_name,
            "source_file": source_file,
            "compose_tree": [],
            "note": "No composable functions found",
        }

    params = builder.extract_function_params(target_node)
    tree = builder.parse_function_body(target_node)
    return {
        "composable_name": screen_name,
        "source_file": source_file,
        "function_parameters": params,
        "compose_tree": tree if tree else [],
    }
_TREE_CHILD_KEYS = ("children", "expanded_children", "then_children", "else_children")


def _iter_tree_nodes(nodes):
    """Yield every dict node of a compose tree, depth-first, parents before children.

    Descends through ``children``, ``expanded_children``, ``then_children``,
    ``else_children``, the per-slot lists in ``content_params`` and the
    ``children`` of every WhenBlock branch.  Non-dict entries are skipped.
    """
    for node in nodes:
        if not isinstance(node, dict):
            continue
        yield node
        for key in _TREE_CHILD_KEYS:
            yield from _iter_tree_nodes(node.get(key, []))
        for slot_children in node.get("content_params", {}).values():
            yield from _iter_tree_nodes(slot_children)
        for branch in node.get("branches", []):
            yield from _iter_tree_nodes(branch.get("children", []))


def _write_json(path: str, data) -> None:
    """Write *data* as indented UTF-8 JSON, keeping non-ASCII characters verbatim."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)


def main():
    """CLI entry point: build compose-tree JSON for everything listed in the discovery file.

    Usage: ``parse_compose_tree_v2.py <project_root> <discovery_json> <output_dir>``.

    Writes one JSON file per navigation screen, shell component and compose
    activity function into ``<output_dir>/compose_trees``.  A second pass then
    expands custom composables in every JSON file in that directory.  Exits
    with status 1 if arguments are missing.
    """
    if len(sys.argv) < 4:
        print("Usage: python3 parse_compose_tree_v2.py <project_root> <discovery_json> <output_dir>")
        sys.exit(1)
    project_root = os.path.abspath(sys.argv[1])
    discovery_path = os.path.abspath(sys.argv[2])
    output_dir = os.path.abspath(sys.argv[3])
    trees_dir = os.path.join(output_dir, "compose_trees")
    os.makedirs(trees_dir, exist_ok=True)
    with open(discovery_path, "r", encoding="utf-8") as f:
        discovery = json.load(f)

    print("[Phase 2 v2] Parsing Compose UI trees with tree-sitter AST")
    print(" Loading string resources...")
    resolver = StringResourceResolver()
    resolver.load(project_root)
    print(f" Loaded {len(resolver.strings)} string resources")
    from salt_ui_mapping import scan_project_widgets
    # Called for its effect on salt_ui_mapping; the result is only reported.
    project_widgets = scan_project_widgets(project_root)
    if project_widgets:
        print(f" Loaded {len(project_widgets)} project-specific widget definitions")

    def count_nodes(nodes) -> int:
        return sum(1 for _ in _iter_tree_nodes(nodes))

    def count_expanded(nodes) -> int:
        return sum(1 for n in _iter_tree_nodes(nodes) if n.get("expanded_from"))

    processed = 0
    errors = 0
    total_nodes = 0

    # 1. Screens registered in the navigation graph.
    for reg in discovery.get("compose_registrations", []):
        screen_name = reg.get("screen_composable")
        source_file = reg.get("source_file")
        if not screen_name or not source_file:
            continue
        result = process_screen(project_root, screen_name, source_file, resolver)
        nc = count_nodes(result.get("compose_tree", []))
        total_nodes += nc
        if "error" in result:
            print(f" ERROR {screen_name}: {result['error']}")
            errors += 1
        else:
            print(f" {screen_name}: {nc} nodes")
            processed += 1
        result["route_constant"] = reg.get("route_constant")
        result["route_pattern"] = reg.get("full_route_pattern")
        result["arguments"] = reg.get("arguments", [])
        _write_json(os.path.join(trees_dir, f"{screen_name}.json"), result)

    # 2. App-shell composables (scaffolds, bars, drawers ...).
    for comp in discovery.get("shell_components", []):
        name = comp.get("name")
        source_file = comp.get("source_file")
        if not name or not source_file:
            continue
        result = process_screen(project_root, name, source_file, resolver)
        nc = count_nodes(result.get("compose_tree", []))
        total_nodes += nc
        result["shell_role"] = comp.get("role")
        if "error" in result:
            print(f" ERROR Shell_{name}: {result['error']}")
            errors += 1
        else:
            print(f" Shell_{name}: {nc} nodes")
            processed += 1
        _write_json(os.path.join(trees_dir, f"Shell_{name}.json"), result)

    # 3. Compose activities: every composable function in the activity's file.
    for activity in discovery.get("activities", []):
        if activity.get("ui_type") != "compose":
            continue
        source_file = activity.get("source_file")
        if not source_file:
            continue
        full_path = os.path.join(project_root, source_file)
        if not os.path.isfile(full_path):
            continue
        # Fall back to the file stem so output names never contain "None".
        simple_name = activity.get("simple_name") or Path(source_file).stem
        with open(full_path, "rb") as f:
            source = f.read()
        builder = ComposeTreeBuilder(source, resolver)
        for fn_name, fn_node in builder.find_composable_functions():
            tree = builder.parse_function_body(fn_node)
            nc = count_nodes(tree)
            total_nodes += nc
            if not tree:
                continue
            out = {
                "composable_name": fn_name,
                "source_file": source_file,
                "activity": activity.get("class"),
                "compose_tree": tree,
            }
            _write_json(os.path.join(trees_dir, f"Activity_{simple_name}_{fn_name}.json"), out)
            processed += 1
            print(f" Activity_{simple_name}_{fn_name}: {nc} nodes")

    # 4. Expand custom composables in every tree file (sorted for stable output).
    print("\n Expanding custom composables...")
    expander = RecursiveExpander(project_root, resolver)
    expanded_count = 0
    for fname in sorted(os.listdir(trees_dir)):
        if not fname.endswith(".json"):
            continue
        fpath = os.path.join(trees_dir, fname)
        with open(fpath, "r", encoding="utf-8") as f:
            data = json.load(f)
        # compose_tree plus, for "target not found" results, each per-function tree.
        trees = [data.get("compose_tree", [])]
        trees.extend(data.get("all_composables", {}).values())
        newly_expanded = 0
        for tree in trees:
            before = count_expanded(tree)
            expander.expand_tree(tree, max_depth=5)
            newly_expanded += count_expanded(tree) - before
        if newly_expanded > 0:
            expanded_count += newly_expanded
            _write_json(fpath, data)
    print(f" Expanded {expanded_count} custom composables")

    print("\n[Phase 2 v2 Complete]")
    print(f" Processed: {processed}")
    print(f" Errors: {errors}")
    print(f" Total nodes: {total_nodes}")
    print(f" Custom expanded: {expanded_count}")
    print(f" Output: {trees_dir}/")
def _count_custom(nodes) -> int:
c = 0
for n in nodes:
if isinstance(n, dict):
if n.get("custom"):
c += 1
for key in ("children", "expanded_children", "then_children", "else_children"):
c += _count_custom(n.get(key, []))
for branch in n.get("branches", []):
c += _count_custom(branch.get("children", []))
return c
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()