import re
import libcst as cst
from libcst.metadata import PositionProvider, ScopeProvider
from .utils import get_docstring, case_insensitive_replace, create_nested_attribute_or_name
from pathlib import Path
import json
def load_json_file(filepath):
with open(filepath, 'r', encoding='utf-8') as file:
data = json.load(file)
return data
class APITransformer(cst.CSTTransformer):
"""
cst transformer module to map torch api to msadapter api
"""
METADATA_DEPENDENCIES = (PositionProvider, ScopeProvider)
def __init__(self, current_name, new_name, string_mapping):
self.string_mapping = string_mapping
self.current_name = current_name
self.new_name = new_name
self.root = None
self.alias_map = {}
self.support_api = load_json_file(f"{Path(__file__).parent}/../mapping_resources/api_mapping.json")
self.matched = False
def _update_docstring(self, original_node, updated_node):
docstring = get_docstring(original_node)
if docstring:
new_value = case_insensitive_replace(docstring.value, self.current_name, self.new_name)
new_statementline = cst.SimpleStatementLine(body=[cst.Expr(cst.SimpleString(new_value))])
return updated_node.with_changes(
body=updated_node.body.with_changes(
body=(new_statementline,) + updated_node.body.body[1:]
)
)
return updated_node
def leave_FormattedStringText(self, original_node, updated_node):
new_value = original_node.value
for (current_name, new_name) in self.string_mapping:
pattern = fr'\b{current_name}\.'
new_value = re.sub(pattern, f'{new_name}.', new_value)
return updated_node.with_changes(value=new_value)
def leave_SimpleString(self, original_node, updated_node):
new_value = original_node.value
for (current_name, new_name) in self.string_mapping:
if new_value.strip('"\'').startswith(current_name):
new_value = re.sub(current_name, new_name, new_value)
return updated_node.with_changes(value=new_value)
def _flat_alias_if_possible(self, call_split):
alias = call_split[0]
module = self.alias_map.get(alias)
if module is not None:
call_split[0] = module
return '.'.join(call_split)
def visit_Module(self, node):
self.root = node
def _parse_alias(self, code):
pattern = r'(.*) +as +(.*)'
m = re.match(pattern, code)
if m:
full_path, alias = m.groups()
return full_path, alias
return None, None
def visit_Import(self, node):
guard_this_import = False
for import_alias in node.names:
code = self.root.code_for_node(import_alias).strip(' ,')
full_path, alias = self._parse_alias(code)
if alias:
self.alias_map[alias] = full_path
if 'safetensors' in code:
guard_this_import = True
return not guard_this_import
def visit_ImportFrom(self, node):
if node.module is None:
return True
if isinstance(node.names, cst.ImportStar):
return True
module = self.root.code_for_node(node.module)
guard_this_import = 'safetensors' in module
for import_alias in node.names:
code = self.root.code_for_node(import_alias).strip(' ,')
sub_path, alias = self._parse_alias(code)
if alias is None:
sub_path, alias = code, code
self.alias_map[alias] = f'{module}.{sub_path}'
return not guard_this_import
def leave_Call(self, original_node, updated_node):
if isinstance(original_node.func, cst.Call):
return updated_node
if isinstance(original_node.func.value, cst.Call):
return updated_node
code = self.root.code_for_node(original_node.func)
call_split = code.split('.')
de_aliased_call = self._flat_alias_if_possible(call_split)
mapped_call = self.support_api.get(de_aliased_call)
if mapped_call is not None:
self.matched = True
return updated_node.with_changes(func=create_nested_attribute_or_name(mapped_call))
return updated_node
def leave_Name(self, original_node, updated_node):
if original_node.value == f'{self.current_name}_npu':
return updated_node.with_changes(value=f'{self.new_name}_npu')
if original_node.value == self.current_name:
return updated_node.with_changes(value=self.new_name)
return updated_node
def leave_Comment(self, original_node, updated_node):
new_value = case_insensitive_replace(original_node.value, self.current_name, self.new_name)
return updated_node.with_changes(value=new_value)
def leave_ClassDef(self, original_node, updated_node):
return self._update_docstring(original_node, updated_node)
def leave_FunctionDef(self, original_node, updated_node):
return self._update_docstring(original_node, updated_node)
def leave_Module(self, original_node, updated_node):
if not self.matched:
return updated_node
import_stmt = "import mindspore"
new_import = cst.parse_statement(import_stmt)
nodes = list(updated_node.body)
insert_index = 0
for i, node in enumerate(nodes):
if '__future__' not in self.root.code_for_node(node):
insert_index = i
break
new_body = nodes[:insert_index] + [new_import] + nodes[insert_index:]
return updated_node.with_changes(body=new_body)