# Last commit: "commit message" by Your Name
# Revision d7290920, created 2024-10-19 (see commit history)
"""
Recreate the Core ML model from scratch using
coremltools' neural_network.NeuralNetworkBuilder
"""
import os

import coremltools
import coremltools.models.datatypes as datatypes
from coremltools.models import neural_network as neural_network
from coremltools.models.utils import save_spec
import numpy as np

# get weights
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name = "./distilgpt2-base-pretrained-he"
save_directory = "tmp/coreml/"
#!mkdir -p $save_directory
file_name = "model.mlmodel"

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
lm_head_model = GPT2LMHeadModel.from_pretrained(model_name).eval()
model = lm_head_model.transformer

wte = model.wte.weight.data.numpy().transpose() # shape (768, 50257) /!\ i hate this
wpe = model.wpe.weight.data.numpy().transpose() # shape (768, 1024)

sequence_length = 64
steps = 6

# build model
#
# Inputs: two vectors of length `sequence_length` holding the token ids and the
# position indices of each token. Output: the LM logits blob `output_logits`.
input_features = [
	('input_ids', datatypes.Array(sequence_length)),
	('position_ids', datatypes.Array(sequence_length)),
]
output_features = [('output_logits', None)]

builder = neural_network.NeuralNetworkBuilder(
	input_features,
	output_features,
	mode=None,
	disable_rank5_shape_mapping=True,
)

# (seq,) -> (seq, 1, 1, 1, 1): the rank-5 layout fed to the embedding layers.
for feature_name in ('input_ids', 'position_ids'):
	builder.add_expand_dims(
		name=f'{feature_name}_expanded_to_rank5',
		input_name=feature_name,
		output_name=f'{feature_name}_expanded_to_rank5',
		axes=(1, 2, 3, 4),
	)

# `wte` / `wpe` are (embedding_dim, num_embeddings). Take the layer sizes from the
# tables themselves instead of hard-coding GPT-2's 50257-token vocabulary, 1024
# positions and 768 width, so a checkpoint with a different tokenizer or size
# still yields layers whose declared sizes match their weights.
for layer_name, feature_name, table in (
	('token_embeddings', 'input_ids', wte),
	('positional_embeddings', 'position_ids', wpe),
):
	embedding_dim, num_embeddings = table.shape
	builder.add_embedding(
		name=layer_name,
		input_name=f'{feature_name}_expanded_to_rank5',
		output_name=layer_name,
		W=table,
		b=None,
		input_dim=num_embeddings,
		output_channels=embedding_dim,
		has_bias=False,
	)

# Token + position embeddings, (seq, 1, embedding_dim, 1, 1). This blob is the
# input of transformer block 0.
builder.add_add_broadcastable(
	name='embeddings_addition',
	input_names=['token_embeddings', 'positional_embeddings'],
	output_name='0_previous_block',
)

# ---------------------------------------------------------------------------
# Transformer blocks
# ---------------------------------------------------------------------------
# Read the width and head count from the checkpoint config instead of
# hard-coding 768 / 12 / 64.
n_embd = model.config.n_embd
n_head = model.config.n_head
head_dim = n_embd // n_head

# Causal mask shared by every layer: causal_mask[q, k] == 1 when key position k
# is at or before query position q, else 0. This equals the
# `attn.bias[:, :, 0:seq, :seq]` slice of the Hugging Face buffer. It is built
# here as a float32 numpy array because that buffer's dtype varies across
# transformers versions (bool in recent ones, where `1 - mask` raises), and
# coremltools expects numpy weights.
causal_mask = np.tril(np.ones((sequence_length, sequence_length), dtype=np.float32))
# Additive term pushing masked (future) scores down to about -1e4 before softmax.
causal_mask_bias = -1e4 * (1.0 - causal_mask)

# Per-head layout of q, k and v. The reshape gives (1, 1, seq, n_head, head_dim).
# The transpose then moves the head axis forward:
#   q, v -> (1, 1, n_head, seq, head_dim)
#   k    -> (1, 1, n_head, head_dim, seq), so that q @ k is the score matrix.
head_permutations = (
	("q", (0, 1, 3, 2, 4)),
	("k", (0, 1, 3, 4, 2)),
	("v", (0, 1, 3, 2, 4)),
)


def _add_layer_norm(nn_builder, name, input_name, layer_norm):
	"""Append a LayerNorm as an MVN layer followed by a per-channel scale + bias.

	The normalized blob is named ``name`` and the affine output
	``name + "_scaled"``; the latter name is returned.
	"""
	scaled_name = f"{name}_scaled"
	n_features = layer_norm.weight.shape[0]
	nn_builder.add_mvn(
		name=name,
		input_name=input_name,
		output_name=name,
		across_channels=True,
		normalize_variance=True,
		epsilon=layer_norm.eps,
	)
	nn_builder.add_scale(
		name=scaled_name,
		input_name=name,
		output_name=scaled_name,
		W=layer_norm.weight.data.numpy(),
		b=layer_norm.bias.data.numpy(),
		has_bias=True,
		shape_scale=[n_features],
		shape_bias=[n_features],
	)
	return scaled_name


def _add_conv1d(nn_builder, name, input_name, conv1d):
	"""Append a Hugging Face ``Conv1D`` as an inner-product layer named ``name``.

	``Conv1D`` stores its weight as (in_features, out_features), while the Core ML
	inner-product layer takes (output_channels, input_channels), hence the
	transpose. Returns the output blob name (``name``).
	"""
	in_features, out_features = conv1d.weight.shape
	nn_builder.add_inner_product(
		name=name,
		input_name=input_name,
		output_name=name,
		input_channels=in_features,
		output_channels=out_features,
		W=conv1d.weight.data.numpy().T,
		b=conv1d.bias.data.numpy(),
		has_bias=True,
	)
	return name


for i in range(steps):
	print(i)
	block = model.h[i]

	# ---- Attention sub-block: h = x + attn(ln_1(x)) ----
	# Block input `{i}_previous_block` is (seq, 1, n_embd, 1, 1).
	_add_layer_norm(builder, f"{i}_block_ln_1", f"{i}_previous_block", block.ln_1)

	# (seq, 1, n_embd, 1, 1) -> (1, seq, n_embd, 1, 1)
	builder.add_transpose(
		name=f"{i}_block_ln_1_reshape",
		input_name=f"{i}_block_ln_1_scaled",
		output_name=f"{i}_block_ln_1_scaled_transposed",
		axes=(1, 0, 2, 3, 4),
	)

	# Fused q/k/v projection -> (1, seq, 3 * n_embd, 1, 1), split into q, k, v.
	_add_conv1d(builder, f"{i}_block_attn_conv", f"{i}_block_ln_1_scaled_transposed", block.attn.c_attn)
	builder.add_split(
		name=f"{i}_block_attn_qkv_split",
		input_name=f"{i}_block_attn_conv",
		output_names=[f"{i}_block_attn_q", f"{i}_block_attn_k", f"{i}_block_attn_v"],
	)

	for part, axes in head_permutations:
		builder.add_rank_preserving_reshape(
			name=f"{i}_block_attn_{part}_reshape",
			input_name=f"{i}_block_attn_{part}",
			output_name=f"{i}_block_attn_{part}_reshape",
			output_shape=(1, 1, sequence_length, n_head, head_dim),
		)
		builder.add_transpose(
			name=f"{i}_block_attn_{part}_reshape_permuted",
			input_name=f"{i}_block_attn_{part}_reshape",
			output_name=f"{i}_block_attn_{part}_reshape_permuted",
			axes=axes,
		)

	# Scores q @ k^T: (1, 1, n_head, seq, seq), scaled by 1/sqrt(head_dim).
	builder.add_batched_mat_mul(
		name=f"{i}_block_attn_qv_matmul",
		input_names=[f"{i}_block_attn_q_reshape_permuted", f"{i}_block_attn_k_reshape_permuted"],
		output_name=f"{i}_block_attn_qv_matmul",
	)
	builder.add_scale(
		name=f"{i}_block_attn_qv_matmul_scaled",
		input_name=f"{i}_block_attn_qv_matmul",
		output_name=f"{i}_block_attn_qv_matmul_scaled",
		W=np.array(1 / np.sqrt(head_dim)),
		b=0,
		has_bias=False,
	)

	# Causal masking: scores * mask - 1e4 * (1 - mask), broadcast over the heads.
	builder.add_scale(
		name=f"{i}_block_attn_bias",
		input_name=f"{i}_block_attn_qv_matmul_scaled",
		output_name=f"{i}_block_attn_bias",
		W=causal_mask,
		b=None,
		has_bias=False,
		shape_scale=[1, sequence_length, sequence_length],
	)
	builder.add_bias(
		name=f"{i}_block_attn_afterbias",
		input_name=f"{i}_block_attn_bias",
		output_name=f"{i}_block_attn_afterbias",
		b=causal_mask_bias,
		shape_bias=[1, sequence_length, sequence_length],
	)

	# Softmax over the key axis (the last one). The legacy `add_softmax` layer
	# always normalizes along axis -3, which is the head axis here, so the N-D
	# variant is used with an explicit axis.
	builder.add_softmax_nd(
		name=f"{i}_block_attn_softmax",
		input_name=f"{i}_block_attn_afterbias",
		output_name=f"{i}_block_attn_softmax",
		axis=-1,
	)

	# Weighted sum of values: (1, 1, n_head, seq, head_dim).
	builder.add_batched_mat_mul(
		name=f"{i}_block_full_attention",
		input_names=[f"{i}_block_attn_softmax", f"{i}_block_attn_v_reshape_permuted"],
		output_name=f"{i}_block_full_attention",
	)

	# Merge heads: (1, 1, seq, n_head, head_dim) -> (1, 1, 1, seq, n_embd)
	# -> (1, seq, n_embd, 1, 1).
	builder.add_transpose(
		name=f"{i}_block_full_attention_merged_t",
		input_name=f"{i}_block_full_attention",
		output_name=f"{i}_block_full_attention_merged_t",
		axes=[0, 1, 3, 2, 4],
	)
	builder.add_rank_preserving_reshape(
		name=f"{i}_block_full_attention_merged",
		input_name=f"{i}_block_full_attention_merged_t",
		output_name=f"{i}_block_full_attention_merged",
		output_shape=[1, 1, 1, sequence_length, n_embd],
	)
	builder.add_transpose(
		name=f"{i}_block_attn_conv_proj_t",
		input_name=f"{i}_block_full_attention_merged",
		output_name=f"{i}_block_attn_conv_proj_t",
		axes=[0, 3, 4, 1, 2],
	)

	# Output projection: (1, seq, n_embd, 1, 1).
	_add_conv1d(builder, f"{i}_block_attn_conv_proj", f"{i}_block_attn_conv_proj_t", block.attn.c_proj)

	# Residual: bring the block input to (1, seq, n_embd, 1, 1) and add.
	builder.add_transpose(
		name=f"{i}_previous_block_t",
		input_name=f"{i}_previous_block",
		output_name=f"{i}_previous_block_t",
		axes=[1, 0, 2, 3, 4],
	)
	builder.add_add_broadcastable(
		name=f"{i}_block_xa_sum",
		input_names=[f"{i}_previous_block_t", f"{i}_block_attn_conv_proj"],
		output_name=f"{i}_block_xa_sum",
	)

	# ---- MLP sub-block: out = h + mlp(ln_2(h)); tensors are (1, seq, C, 1, 1) ----
	_add_layer_norm(builder, f"{i}_block_ln_2", f"{i}_block_xa_sum", block.ln_2)
	_add_conv1d(builder, f"{i}_block_mlp_conv_fc", f"{i}_block_ln_2_scaled", block.mlp.c_fc)
	builder.add_gelu(
		name=f"{i}_block_mlp_gelu",
		input_name=f"{i}_block_mlp_conv_fc",
		output_name=f"{i}_block_mlp_gelu",
		mode='TANH_APPROXIMATION',
	)
	_add_conv1d(builder, f"{i}_block_mlp_conv_proj", f"{i}_block_mlp_gelu", block.mlp.c_proj)
	builder.add_add_broadcastable(
		name=f"{i}_block_xm_sum",
		input_names=[f"{i}_block_xa_sum", f"{i}_block_mlp_conv_proj"],
		output_name=f"{i + 1}_previous_block_final",
	)

	# Back to (seq, 1, n_embd, 1, 1) as the next block's input. The final layer
	# norm reads `{steps}_previous_block_final` directly.
	builder.add_transpose(
		name=f"{i}_block_xm_sum_t",
		input_name=f"{i + 1}_previous_block_final",
		output_name=f"{i + 1}_previous_block",
		axes=[1, 0, 2, 3, 4],
	)


# ---------------------------------------------------------------------------
# Final layer norm and language-model head
# ---------------------------------------------------------------------------
ln_f_weight = model.ln_f.weight.data.numpy()
ln_f_bias = model.ln_f.bias.data.numpy()
ln_f_epsilon = model.ln_f.eps
n_ln_f_features = ln_f_weight.shape[0]

# Input: the last block's output `{steps}_previous_block_final`, (1, seq, C, 1, 1).
builder.add_mvn(
	name="ln_f",
	input_name=f"{steps}_previous_block_final",
	output_name="ln_f",
	across_channels=True,
	normalize_variance=True,
	epsilon=ln_f_epsilon,
)

builder.add_scale(
	name="ln_f_scaled",
	input_name="ln_f",
	output_name="ln_f_scaled",
	W=ln_f_weight,
	b=ln_f_bias,
	has_bias=True,
	shape_scale=[n_ln_f_features],
	shape_bias=[n_ln_f_features],
)

# The LM head weight is (vocab_size, hidden_size). That is already the
# (output_channels, input_channels) layout of an inner-product layer. Take the
# sizes from it rather than hard-coding 50257 / 768.
lm_head_weights = lm_head_model.lm_head.weight.data.numpy()
vocab_size, hidden_size = lm_head_weights.shape

# Output: `output_logits`, (1, seq, vocab_size, 1, 1).
builder.add_inner_product(
	name="lm_head",
	input_name="ln_f_scaled",
	output_name="output_logits",
	input_channels=hidden_size,
	output_channels=vocab_size,
	W=lm_head_weights,
	b=None,
	has_bias=False,
)

# compile spec to model
mlmodel = coremltools.models.MLModel(builder.spec)

# Make sure the output directory exists before writing the model file.
os.makedirs(save_directory, exist_ok=True)
save_spec(builder.spec, os.path.join(save_directory, file_name))

# Optional sanity check: run the converted model on dummy inputs.
# input_data = {
# 	'input_ids': np.zeros(sequence_length, dtype=np.float32),
# 	'position_ids': np.arange(sequence_length, dtype=np.float32),
# }
# predictions = mlmodel.predict(input_data)["output_logits"]
# print(predictions)