"""
Phase 2: Parse Compose UI Trees
Static analysis of @Composable function bodies into hierarchical JSON component trees.
Tracks brace-depth nesting, detects composable calls, extracts parameters, handles
conditional blocks, and maps Salt UI components.
Usage:
python3 parse_compose_tree.py <project_root> <discovery_json> <output_dir>
Output:
<output_dir>/compose_trees/<ScreenName>.json (one per screen composable)
"""
import json
import os
import re
import sys
import xml.etree.ElementTree as ET
from pathlib import Path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from salt_ui_mapping import get_component_info, is_non_ui_call, COMPOSE_NON_UI
class StringResourceResolver:
    """Resolve ``stringResource(R.string.xxx)`` references to their declared text.

    String values are read from every ``values*`` directory under
    ``<project_root>/app/src/main/res``. Entries from the default ``values``
    directory always win; qualified directories (``values-xx``) only supply
    names the default directory does not define.
    """

    # One backslash escape inside an Android string value: \uXXXX or \<char>.
    _ESCAPE_RE = re.compile(r'\\(u[0-9a-fA-F]{4}|.)', re.DOTALL)

    def __init__(self):
        # Maps string resource name -> resolved display text.
        self.strings = {}

    def load(self, project_root: str):
        """Load all string resources from the project's ``res/values*`` directories.

        Does nothing when the ``res`` directory does not exist.
        """
        res_base = os.path.join(project_root, "app", "src", "main", "res")
        if not os.path.isdir(res_base):
            return
        # "values" sorts before every "values-xx", so defaults are loaded first.
        for values_dir_name in sorted(os.listdir(res_base)):
            if not values_dir_name.startswith("values"):
                continue
            values_dir = os.path.join(res_base, values_dir_name)
            if not os.path.isdir(values_dir):
                continue
            is_default = values_dir_name == "values"
            # Sorted so duplicate names across files resolve the same way on every run.
            for fname in sorted(os.listdir(values_dir)):
                if fname.endswith(".xml"):
                    self._parse_xml(os.path.join(values_dir, fname), is_default=is_default)

    @classmethod
    def _unescape_android(cls, text: str) -> str:
        """Undo Android resource escaping: surrounding quotes and backslash escapes."""
        if len(text) >= 2 and text[0] == '"' and text[-1] == '"':
            text = text[1:-1]

        def replace(match):
            token = match.group(1)
            if len(token) == 5:  # 'u' followed by four hex digits
                return chr(int(token[1:], 16))
            return {"n": "\n", "t": "\t"}.get(token, token)

        return cls._ESCAPE_RE.sub(replace, text)

    def _parse_xml(self, path: str, is_default: bool = False):
        """Collect ``<string name="...">`` entries from one resource XML file.

        Files that cannot be read or parsed are skipped (best effort).
        Entries from a default file overwrite existing names; other files only
        add names that are still missing.
        """
        try:
            tree = ET.parse(path)
        except (ET.ParseError, OSError):
            return
        for elem in tree.getroot():
            if elem.tag != "string":
                continue
            name = elem.get("name", "")
            if not name or not (is_default or name not in self.strings):
                continue
            # itertext() keeps text that sits inside or after inline markup such
            # as <b> or <xliff:g>; elem.text alone stops at the first child tag.
            raw = "".join(elem.itertext()).strip()
            self.strings[name] = self._unescape_android(raw)

    def resolve(self, param_value: str) -> str:
        """Return the text for a ``stringResource(R.string.xxx)`` expression.

        Also accepts the ``stringResource(id = R.string.xxx)`` form. The input
        is returned unchanged when it holds no such call or the name is unknown.
        """
        m = re.search(r'stringResource\s*\(\s*(?:id\s*=\s*)?R\.string\.(\w+)', param_value)
        if m:
            key = m.group(1)
            if key in self.strings:
                return self.strings[key]
        return param_value

    def resolve_tree(self, nodes: list):
        """Recursively resolve string resources in the ``params`` of tree nodes.

        Resolved values replace the parameter in place and are also recorded
        under the node's ``resolved_strings`` mapping. Non-dict entries are
        ignored.
        """
        if not nodes:
            return
        for node in nodes:
            if not isinstance(node, dict):
                continue
            params = node.get("params")
            if isinstance(params, dict):
                for key, val in list(params.items()):
                    if not (isinstance(val, str) and "R.string." in val):
                        continue
                    resolved = self.resolve(val)
                    if resolved != val:
                        params[key] = resolved
                        node.setdefault("resolved_strings", {})[key] = resolved
            self.resolve_tree(node.get("children", []))
            self.resolve_tree(node.get("then_children", []))
            self.resolve_tree(node.get("else_children", []))
            # Present only on nodes that were already expanded by a caller.
            self.resolve_tree(node.get("expanded_children", []))
            for branch in node.get("branches", []):
                self.resolve_tree(branch.get("children", []))
class CustomComposableExpander:
"""One-level expansion of custom composable definitions."""
def __init__(self, project_root: str):
self.project_root = project_root
self._source_cache: dict[str, str] = {}
self._kt_index: dict[str, str] | None = None
def _build_kt_index(self):
"""Build index of all .kt files for name lookup."""
if self._kt_index is not None:
return
self._kt_index = {}
for root, _, files in os.walk(self.project_root):
for f in files:
if f.endswith(".kt"):
self._kt_index[f[:-3]] = os.path.join(root, f)
def _read_source(self, path: str) -> str:
if path not in self._source_cache:
with open(path, "r", encoding="utf-8", errors="ignore") as f:
self._source_cache[path] = f.read()
return self._source_cache[path]
def _find_composable_source(self, name: str, context_source: str = "") -> str | None:
"""Find source file containing @Composable fun <name>(...)."""
if context_source and re.search(rf'\bfun\s+{re.escape(name)}\s*\(', context_source):
return context_source
self._build_kt_index()
if name in self._kt_index:
return self._read_source(self._kt_index[name])
for fname, fpath in self._kt_index.items():
if "/ui/" not in fpath and "\\ui\\" not in fpath:
continue
src = self._read_source(fpath)
if re.search(rf'@Composable\s[^{{]*\bfun\s+{re.escape(name)}\s*\(', src, re.DOTALL):
return src
return None
def expand_custom_nodes(self, nodes: list, context_source: str = "", depth: int = 0):
"""Expand custom composable nodes one level deep."""
if not nodes or depth > 1:
return
for node in nodes:
if not isinstance(node, dict):
continue
if node.get("custom") and node.get("note") == "needs_resolution":
comp_name = node.get("type", "")
src = self._find_composable_source(comp_name, context_source)
if src:
parser = ComposeTreeParser(src, comp_name)
children = parser.parse_tree()
if children:
node["expanded_children"] = children
node["note"] = "auto_expanded_one_level"
self.expand_custom_nodes(node.get("children", []), context_source, depth + 1)
self.expand_custom_nodes(node.get("then_children", []), context_source, depth + 1)
self.expand_custom_nodes(node.get("else_children", []), context_source, depth + 1)
for branch in node.get("branches", []):
self.expand_custom_nodes(branch.get("children", []), context_source, depth + 1)
class ComposeTreeParser:
"""Parses a @Composable function body into a hierarchical component tree."""
def __init__(self, source: str, function_name: str):
self.source = source
self.function_name = function_name
self.lines = source.split("\n")
def find_function_body(self) -> tuple[int, int] | None:
"""Find the start and end line indices of the @Composable function body."""
func_pattern = re.compile(rf'\bfun\s+{re.escape(self.function_name)}\s*\(')
start_line = None
for i, line in enumerate(self.lines):
if func_pattern.search(line):
start_line = i
break
if start_line is None:
return None
brace_depth = 0
body_start = None
for i in range(start_line, len(self.lines)):
for ch in self.lines[i]:
if ch == "{":
brace_depth += 1
if body_start is None:
body_start = i
elif ch == "}":
brace_depth -= 1
if brace_depth == 0 and body_start is not None:
return (body_start, i)
return None
def extract_function_params(self) -> list[dict]:
"""Extract the function's parameters from its signature."""
func_pattern = re.compile(rf'\bfun\s+{re.escape(self.function_name)}\s*\(')
start_line = None
for i, line in enumerate(self.lines):
if func_pattern.search(line):
start_line = i
break
if start_line is None:
return []
sig_lines = []
paren_depth = 0
for i in range(start_line, min(start_line + 50, len(self.lines))):
line = self.lines[i]
sig_lines.append(line)
paren_depth += line.count("(") - line.count(")")
if paren_depth <= 0 and ")" in line:
break
signature = " ".join(sig_lines)
m = re.search(r'\(\s*(.*)\s*\)', signature, re.DOTALL)
if not m:
return []
param_text = m.group(1)
params = []
for param in self._split_params(param_text):
param = param.strip()
if not param:
continue
pm = re.match(r'(\w+)\s*:\s*([^=]+?)(?:\s*=\s*(.+))?$', param.strip())
if pm:
params.append({
"name": pm.group(1).strip(),
"type": pm.group(2).strip(),
"default": pm.group(3).strip() if pm.group(3) else None,
})
return params
def _split_params(self, text: str) -> list[str]:
"""Split parameter text by commas, respecting nested parens/angles/braces."""
parts = []
depth = 0
current = []
for ch in text:
if ch in "(<{":
depth += 1
current.append(ch)
elif ch in ")>}":
depth -= 1
current.append(ch)
elif ch == "," and depth == 0:
parts.append("".join(current))
current = []
else:
current.append(ch)
if current:
parts.append("".join(current))
return parts
def parse_tree(self) -> dict | None:
"""Parse the function body into a hierarchical component tree."""
bounds = self.find_function_body()
if not bounds:
return None
body_start, body_end = bounds
body_lines = self.lines[body_start:body_end + 1]
body_text = "\n".join(body_lines)
first_brace = body_text.index("{")
last_brace = body_text.rindex("}")
inner_text = body_text[first_brace + 1:last_brace]
children = self._parse_block(inner_text, body_start + 1)
return children
def _parse_block(self, text: str, base_line: int = 0, _depth: int = 0) -> list[dict]:
"""Parse a block of Compose code into a list of component nodes."""
if _depth > 20 or not text.strip():
return []
nodes = []
lines = text.split("\n")
i = 0
while i < len(lines):
line = lines[i].strip()
if not line or line.startswith("//") or line.startswith("*") or line.startswith("/*"):
i += 1
continue
if re.match(r'(val|var|fun )\s', line) and not re.match(r'(val|var)\s+\w+\s*=\s*[A-Z]', line):
brace_d = line.count("{") - line.count("}")
paren_d = line.count("(") - line.count(")")
while (brace_d > 0 or paren_d > 0) and i + 1 < len(lines):
i += 1
line = lines[i].strip()
brace_d += line.count("{") - line.count("}")
paren_d += line.count("(") - line.count(")")
i += 1
continue
if_match = re.match(r'if\s*\((.+)\)\s*\{?\s*$', line)
if not if_match:
if re.match(r'if\s*\(', line):
collected = line
paren_d = line.count("(") - line.count(")")
while paren_d > 0 and i + 1 < len(lines):
i += 1
collected += " " + lines[i].strip()
paren_d += lines[i].count("(") - lines[i].count(")")
if_match = re.match(r'if\s*\((.+)\)\s*\{?\s*$', collected)
if if_match:
condition = if_match.group(1).strip()
if_block, end_i = self._extract_brace_block(lines, i)
i = end_i + 1
node = {
"type": "ConditionalBlock",
"condition": condition,
"then_children": self._parse_block(if_block, base_line + i, _depth + 1),
}
if i < len(lines) and lines[i].strip().startswith("else"):
else_line = lines[i].strip()
if "else if" in else_line or "else if" in else_line:
pass
else_block, end_i = self._extract_brace_block(lines, i)
i = end_i + 1
node["else_children"] = self._parse_block(else_block, base_line + i, _depth + 1)
if node.get("then_children") or node.get("else_children"):
nodes.append(node)
continue
when_match = re.match(r'when\s*\((.+)\)\s*\{', line)
if when_match:
expr = when_match.group(1).strip()
when_block, end_i = self._extract_brace_block(lines, i)
i = end_i + 1
branches = self._parse_when_branches(when_block, _depth)
if branches:
nodes.append({
"type": "WhenBlock",
"expression": expr,
"branches": branches,
})
continue
comp_match = re.match(r'([A-Z][a-zA-Z0-9]*)\s*(\(|\{)', line)
if comp_match:
comp_name = comp_match.group(1)
if is_non_ui_call(comp_name):
brace_d = line.count("{") - line.count("}")
paren_d = line.count("(") - line.count(")")
while (brace_d > 0 or paren_d > 0) and i + 1 < len(lines):
i += 1
brace_d += lines[i].count("{") - lines[i].count("}")
paren_d += lines[i].count("(") - lines[i].count(")")
i += 1
continue
call_text, end_i = self._extract_composable_call(lines, i)
i = end_i + 1
node = self._parse_composable_call(comp_name, call_text, _depth)
if node:
nodes.append(node)
continue
items_match = re.match(r'(items|itemsIndexed|item|stickyHeader)\s*\(', line)
if items_match:
call_name = items_match.group(1)
call_text, end_i = self._extract_composable_call(lines, i)
i = end_i + 1
trailing_lambda = self._extract_trailing_lambda(call_text)
children = []
if trailing_lambda:
children = self._parse_block(trailing_lambda, base_line + i, _depth + 1)
if children:
nodes.append({
"type": f"LazyListScope.{call_name}",
"params": self._extract_params_text(call_text),
"children": children,
})
continue
i += 1
return nodes
def _extract_brace_block(self, lines: list[str], start_i: int) -> tuple[str, int]:
"""Extract a {...} block starting from line start_i. Returns (block_content, end_line_index)."""
brace_depth = 0
collecting = False
block_lines = []
i = start_i
found_opening = False
while i < len(lines):
line = lines[i]
line_content_parts = []
for j, ch in enumerate(line):
if ch == "{":
brace_depth += 1
if not found_opening:
found_opening = True
collecting = True
continue
elif ch == "}":
brace_depth -= 1
if brace_depth == 0 and collecting:
remaining = "".join(line_content_parts)
if remaining.strip():
block_lines.append(remaining)
return ("\n".join(block_lines), i)
if collecting:
line_content_parts.append(ch)
if collecting and found_opening:
block_lines.append("".join(line_content_parts))
i += 1
if not found_opening:
return ("", start_i)
return ("\n".join(block_lines), max(i - 1, start_i))
def _extract_composable_call(self, lines: list[str], start_i: int) -> tuple[str, int]:
"""Extract a full composable call including parens and trailing lambda."""
call_lines = []
brace_depth = 0
paren_depth = 0
i = start_i
while i < len(lines):
line = lines[i]
call_lines.append(line)
for j, ch in enumerate(line):
if ch == "(":
paren_depth += 1
elif ch == ")":
paren_depth -= 1
elif ch == "{":
brace_depth += 1
elif ch == "}":
brace_depth -= 1
if brace_depth <= 0 and paren_depth <= 0 and i > start_i:
return ("\n".join(call_lines), i)
if brace_depth == 0 and paren_depth == 0:
return ("\n".join(call_lines), i)
i += 1
return ("\n".join(call_lines), i - 1)
def _parse_composable_call(self, name: str, call_text: str, _depth: int = 0) -> dict | None:
"""Parse a composable call into a node with params and children."""
comp_info = get_component_info(name)
node = {
"type": name,
}
if comp_info:
node["salt_semantic"] = comp_info.get("type")
node["library"] = comp_info.get("library")
node["arkts_hint"] = comp_info.get("arkts_hint")
if comp_info.get("clickable"):
node["clickable"] = True
else:
node["custom"] = True
node["note"] = "needs_resolution"
params = self._extract_params_text(call_text)
if params:
node["params"] = params
actions = self._extract_actions(call_text)
if actions:
node["actions"] = actions
trailing_lambda = self._extract_trailing_lambda(call_text)
if trailing_lambda:
children = self._parse_block(trailing_lambda, 0, _depth + 1)
if children:
node["children"] = children
return node
def _extract_params_text(self, call_text: str) -> dict:
"""Extract named parameters from a composable call as raw strings."""
params = {}
paren_start = call_text.find("(")
if paren_start == -1:
return params
depth = 0
paren_end = -1
for i in range(paren_start, len(call_text)):
if call_text[i] == "(":
depth += 1
elif call_text[i] == ")":
depth -= 1
if depth == 0:
paren_end = i
break
if paren_end == -1:
return params
param_text = call_text[paren_start + 1:paren_end]
parts = self._split_at_depth_zero(param_text, ",")
for part in parts:
part = part.strip()
if not part:
continue
eq_match = re.match(r'(\w+)\s*=\s*(.+)', part, re.DOTALL)
if eq_match:
pname = eq_match.group(1).strip()
pvalue = eq_match.group(2).strip()
if pvalue.startswith("{"):
continue
pvalue = " ".join(pvalue.split())
params[pname] = pvalue
else:
part_clean = " ".join(part.split())
if part_clean and not part_clean.startswith("{"):
params[f"_pos_{len([k for k in params if k.startswith('_pos_')])}"] = part_clean
return params
def _extract_actions(self, call_text: str) -> list[dict]:
"""Extract event handler actions from a composable call."""
actions = []
event_patterns = [
(r'onClick\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onClick"),
(r'onLongClick\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onLongClick"),
(r'onChange\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onChange"),
(r'onValueChange\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onValueChange"),
(r'onConfirm\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onConfirm"),
(r'onDismissRequest\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onDismissRequest"),
]
for pattern, event_name in event_patterns:
for m in re.finditer(pattern, call_text, re.DOTALL):
handler_body = m.group(1).strip()
action = {"event": event_name}
nav_match = re.search(r'navController\.\w*[Nn]avigate\s*\(\s*(?:ScreenRoute\.(\w+)|"([^"]*)")', handler_body)
if nav_match:
target = nav_match.group(1) or nav_match.group(2)
action["navigates_to"] = target
activity_match = re.search(r'startActivity\s*\(.*?(\w+)::class', handler_body)
if activity_match:
action["starts_activity"] = activity_match.group(1)
if len(handler_body) < 200:
action["handler_body"] = handler_body
actions.append(action)
return actions
def _extract_trailing_lambda(self, call_text: str) -> str | None:
"""Extract the trailing lambda content from a composable call."""
paren_depth = 0
brace_depth = 0
last_brace_start = -1
in_parens = False
for i, ch in enumerate(call_text):
if ch == "(":
paren_depth += 1
in_parens = True
elif ch == ")":
paren_depth -= 1
elif ch == "{" and paren_depth == 0:
brace_depth += 1
if brace_depth == 1:
last_brace_start = i
elif ch == "}" and paren_depth == 0:
brace_depth -= 1
if last_brace_start == -1:
return None
depth = 0
for i in range(last_brace_start, len(call_text)):
if call_text[i] == "{":
depth += 1
elif call_text[i] == "}":
depth -= 1
if depth == 0:
content = call_text[last_brace_start + 1:i]
return content.strip() if content.strip() else None
return None
def _split_at_depth_zero(self, text: str, delimiter: str) -> list[str]:
"""Split text by delimiter only at depth 0 (outside nested parens/braces/angles)."""
parts = []
depth = 0
current = []
in_string = False
escape_next = False
for ch in text:
if escape_next:
current.append(ch)
escape_next = False
continue
if ch == "\\":
escape_next = True
current.append(ch)
continue
if ch == '"' and not in_string:
in_string = True
current.append(ch)
continue
if ch == '"' and in_string:
in_string = False
current.append(ch)
continue
if in_string:
current.append(ch)
continue
if ch in "({<[":
depth += 1
elif ch in ")}>]":
depth -= 1
if ch == delimiter[0] and depth == 0 and len(delimiter) == 1:
parts.append("".join(current))
current = []
continue
current.append(ch)
if current:
parts.append("".join(current))
return parts
def _parse_when_branches(self, when_block: str, _depth: int = 0) -> list[dict]:
"""Parse when block branches."""
branches = []
lines = when_block.split("\n")
i = 0
while i < len(lines):
line = lines[i].strip()
if not line or line.startswith("//"):
i += 1
continue
branch_match = re.match(r'(.+?)\s*->\s*\{?\s*$', line)
if branch_match:
condition = branch_match.group(1).strip()
if "{" in line:
block, end_i = self._extract_brace_block(lines, i)
i = end_i + 1
children = self._parse_block(block, 0, _depth + 1)
else:
children = []
i += 1
branches.append({
"condition": condition,
"children": children,
})
else:
i += 1
return branches
def process_screen(project_root: str, screen_name: str, source_file: str, output_dir: str,
                   resolver: StringResourceResolver | None = None,
                   expander: CustomComposableExpander | None = None) -> dict:
    """Build the tree record for one screen composable.

    Reads ``source_file`` (relative to ``project_root``), parses the function
    named ``screen_name``, and optionally resolves string resources and expands
    custom composables in the resulting tree. ``output_dir`` is accepted but
    not used here.

    Returns a dict with the composable name, source file, signature parameters
    and tree, or ``{"error": ...}`` when the source file does not exist.
    """
    source_path = os.path.join(project_root, source_file)
    if not os.path.exists(source_path):
        return {"error": f"Source file not found: {source_file}"}

    with open(source_path, "r", encoding="utf-8", errors="ignore") as handle:
        kotlin_source = handle.read()

    tree_parser = ComposeTreeParser(kotlin_source, screen_name)
    signature_params = tree_parser.extract_function_params()
    compose_tree = tree_parser.parse_tree()

    # Post-processing only applies when the parse produced at least one node.
    if compose_tree:
        if resolver:
            resolver.resolve_tree(compose_tree)
        if expander:
            expander.expand_custom_nodes(compose_tree, context_source=kotlin_source)

    return {
        "composable_name": screen_name,
        "source_file": source_file,
        "function_parameters": signature_params,
        "compose_tree": compose_tree or [],
    }
def _write_tree_json(path: str, payload: dict) -> None:
    """Write one result record as indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def main():
    """Command-line entry point for Phase 2.

    Reads the discovery JSON and writes one tree JSON per screen composable,
    per composable found in inline-``setContent`` activities, and per shell
    component into ``<output_dir>/compose_trees``. Exits with status 1 when
    fewer than three arguments are given.
    """
    if len(sys.argv) < 4:
        print("Usage: python3 parse_compose_tree.py <project_root> <discovery_json> <output_dir>")
        sys.exit(1)
    project_root = os.path.abspath(sys.argv[1])
    discovery_path = os.path.abspath(sys.argv[2])
    output_dir = os.path.abspath(sys.argv[3])
    trees_dir = os.path.join(output_dir, "compose_trees")
    os.makedirs(trees_dir, exist_ok=True)
    with open(discovery_path, "r", encoding="utf-8") as f:
        discovery = json.load(f)

    print("[Phase 2] Parsing Compose UI trees")
    print("  Loading string resources...")
    resolver = StringResourceResolver()
    resolver.load(project_root)
    print(f"  Loaded {len(resolver.strings)} string resources")
    expander = CustomComposableExpander(project_root)
    processed = 0
    errors = 0

    # 1. Screens registered in the navigation graph.
    for reg in discovery.get("compose_registrations", []):
        screen_name = reg.get("screen_composable")
        source_file = reg.get("source_file")
        if not screen_name or not source_file:
            print(f"  SKIP: {reg.get('route_constant')} (no screen composable or source)")
            continue
        print(f"  Parsing {screen_name} from {source_file}...")
        result = process_screen(project_root, screen_name, source_file, output_dir, resolver, expander)
        if "error" in result:
            print(f"    ERROR: {result['error']}")
            errors += 1
        else:
            tree_count = len(result.get("compose_tree", []))
            print(f"    OK: {tree_count} top-level nodes")
            processed += 1
        # The record is written for failed screens too, so the error stays visible.
        result["route_constant"] = reg.get("route_constant")
        result["route_pattern"] = reg.get("full_route_pattern")
        result["arguments"] = reg.get("arguments", [])
        _write_tree_json(os.path.join(trees_dir, f"{screen_name}.json"), result)

    # 2. Compose activities whose UI is declared inline in setContent: every
    #    @Composable function in the activity's file is parsed.
    for activity in discovery.get("activities", []):
        if activity.get("ui_type") != "compose":
            continue
        entry_fn = activity.get("compose_entry_function")
        source_file = activity.get("source_file")
        if not source_file:
            continue
        if entry_fn and "(inline)" not in entry_fn:
            # A named entry function is not handled by this section.
            continue
        simple_name = activity.get("simple_name", "Unknown")
        print(f"  Parsing Activity {simple_name} (inline setContent)...")
        activity_path = os.path.join(project_root, source_file)
        if not os.path.exists(activity_path):
            # Same handling as process_screen: report and continue, do not crash.
            print(f"    ERROR: Source file not found: {source_file}")
            errors += 1
            continue
        with open(activity_path, "r", encoding="utf-8", errors="ignore") as f:
            source = f.read()
        composable_fns = re.findall(r'@Composable\s+(?:private\s+)?fun\s+(\w+)\s*\(', source)
        for fn_name in composable_fns:
            tree = ComposeTreeParser(source, fn_name).parse_tree()
            if not tree:
                continue
            resolver.resolve_tree(tree)
            expander.expand_custom_nodes(tree, context_source=source)
            result = {
                "composable_name": fn_name,
                "source_file": source_file,
                "activity": activity.get("class"),
                "compose_tree": tree,
            }
            _write_tree_json(
                os.path.join(trees_dir, f"Activity_{simple_name}_{fn_name}.json"), result)
            processed += 1

    # 3. Shell components, processed with the same resolver and expander as
    #    screens so their trees carry the same information.
    for comp in discovery.get("shell_components", []):
        name = comp.get("name")
        source_file = comp.get("source_file")
        if not name or not source_file:
            continue
        print(f"  Parsing shell component {name}...")
        result = process_screen(project_root, name, source_file, output_dir, resolver, expander)
        if "error" in result:
            print(f"    ERROR: {result['error']}")
            errors += 1
        else:
            processed += 1
        result["shell_role"] = comp.get("role")
        _write_tree_json(os.path.join(trees_dir, f"Shell_{name}.json"), result)

    print("\n[Phase 2 Complete]")
    print(f"  Processed: {processed}")
    print(f"  Errors: {errors}")
    print(f"  Output: {trees_dir}/")
# Run the Phase 2 pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()