import json
import os
import sys
from argparse import ArgumentParser
from collections import OrderedDict
from contextlib import contextmanager
from operator import itemgetter
from time import time
from pyspark import broadcast
from pyspark.sql import Row, SparkSession, Window
from pyspark.sql.functions import *
from pyspark.sql.types import *
# Positions of the raw input columns; column i is addressed as '_c<i>' throughout.
LABEL_COL = 0  # read as an integer column by load_raw
INT_COLS = list(range(1, 14))  # columns 1..13, read as integers by load_raw
CAT_COLS = list(range(14, 40))  # columns 14..39, read as strings by load_raw
def get_column_counts_with_frequency_limit(df, frequency_limit = None):
    """Count how often each non-null value occurs in every categorical column.

    Args:
        df: DataFrame holding the '_c<i>' columns listed in CAT_COLS.
        frequency_limit: optional comma-separated string. An entry that is a
            bare number sets a default minimum count; an entry of the form
            '<column>:<number>' sets the minimum count for that raw column
            index. Entries with more than one ':' are ignored.

    Returns:
        DataFrame with 'column_id' (position of the column inside CAT_COLS),
        'data' (the value) and 'count', with rows below the applicable
        minimum removed.
    """
    cat_names = ['_c%d' % i for i in CAT_COLS]
    exploded = df.select(posexplode(array(*cat_names)))
    exploded = exploded.withColumnRenamed('pos', 'column_id').withColumnRenamed('col', 'data')
    df = exploded.filter('data is not null').groupBy('column_id', 'data').count()

    if not frequency_limit:
        return df

    overridden = []
    default_limit = None
    for entry in frequency_limit.split(","):
        parts = entry.split(":")
        if len(parts) == 1:
            default_limit = int(parts[0])
        elif len(parts) == 2:
            raw_index = int(parts[0])
            minimum = int(parts[1])
            # Keep every row of the other columns; this column must reach its minimum.
            other_column = col('column_id') != raw_index - CAT_COLS[0]
            df = df.filter(other_column | (col('count') >= minimum))
            overridden.append(raw_index)

    if default_limit:
        # The default only applies to columns without an explicit per-column limit.
        remaining = [x - CAT_COLS[0] for x in CAT_COLS if x not in overridden]
        df = df.filter((~col('column_id').isin(remaining)) | (col('count') >= default_limit))
    return df
def assign_id_with_window(df):
    """Add a 1-based 'id' per column, numbered by descending 'count'.

    The 'count' column is renamed to 'model_count' in the result.
    """
    per_column_by_count = Window.partitionBy('column_id').orderBy(desc('count'))
    numbered = df.withColumn('id', row_number().over(per_column_by_count))
    return numbered.withColumnRenamed('count', 'model_count')
def assign_low_mem_partial_ids(df):
    """Sort the counts and tag each row with 'part_id' and 'mono_id'.

    Rows are ordered by 'column_id' ascending, then 'count' descending.
    'part_id' is the Spark partition id, and 'mono_id' is
    monotonically_increasing_id() minus (part_id << 33).
    """
    ordered = df.orderBy(asc('column_id'), desc('count'))
    tagged = ordered.withColumn('part_id', spark_partition_id())
    # Spark documents the partition id as sitting above the low 33 bits of
    # monotonically_increasing_id(), so subtracting it leaves the per-partition part.
    partition_base = shiftLeft(col('part_id'), 33)
    return tagged.withColumn('mono_id', monotonically_increasing_id() - partition_base)
def assign_low_mem_final_ids(df):
    """Turn the per-partition 'mono_id' values into a 1-based 'id' per column.

    Args:
        df: DataFrame with 'column_id', 'part_id', 'mono_id', 'data' and 'count'.

    Returns:
        DataFrame with 'column_id', 'data', 'id' (integer) and 'model_count'.
    """
    # Per (column, partition): lowest mono_id and the size of the mono_id span.
    extents = (df
        .groupBy('column_id', 'part_id')
        .agg(max('mono_id').alias('top'), min('mono_id').alias('bottom'))
        .withColumn('diff', col('top') - col('bottom') + 1)
        .drop('top'))

    # Sum of the spans of all earlier partitions of the same column; the first
    # partition has no preceding rows, so its null sum is replaced by 0.
    earlier_parts = (Window
        .partitionBy('column_id')
        .orderBy('part_id')
        .rowsBetween(Window.unboundedPreceding, -1))
    extents = (extents
        .withColumn('running_sum', sum('diff').over(earlier_parts))
        .na.fill(0, ["running_sum"]))

    # Rename the key columns on the data side so the join condition is unambiguous.
    renamed = (df
        .withColumnRenamed('column_id', 'i_column_id')
        .withColumnRenamed('part_id', 'i_part_id')
        .withColumnRenamed('count', 'model_count'))
    same_group = (col('i_column_id') == col('column_id')) & (col('part_id') == col('i_part_id'))
    final_id = (col('mono_id') - col('bottom') + col('running_sum') + 1).cast(IntegerType())

    return renamed.join(extents, same_group).select(
        col('column_id'),
        col('data'),
        final_id.alias('id'),
        col('model_count'))
def get_column_models(combined_model):
    """Yield (raw column index, model) for every column in CAT_COLS.

    Each model is the slice of ``combined_model`` whose 'column_id' equals the
    column's offset from CAT_COLS[0], with the 'column_id' column dropped.
    """
    for raw_index in CAT_COLS:
        predicate = 'column_id == %d' % (raw_index - CAT_COLS[0])
        yield raw_index, combined_model.filter(predicate).drop('column_id')
def col_of_rand_long():
    """Return a column expression: rand() scaled by 2**52 and cast to long."""
    scaled = rand() * (1 << 52)
    return scaled.cast(LongType())
def skewed_join(df, model, col_name, cutoff):
    """Replace the values of ``col_name`` with model ids using two left joins.

    ``model`` is split on 'model_count': rows at or above ``cutoff`` are joined
    with a broadcast hint, the remaining rows with a regular join keyed on
    ``col_name`` plus a helper 'join_rand' column.

    Args:
        df: DataFrame containing ``col_name``; an 'ordinal' column is reused as
            the helper value when present.
        model: DataFrame with 'data', 'id' and 'model_count' columns.
        col_name: name of the column to encode.
        cutoff: 'model_count' threshold separating the two joins.

    Returns:
        ``df`` with ``col_name`` holding the matched 'id' (null when neither
        join matched) and the helper columns removed.
    """
    # First join: the frequent part of the model (model_count >= cutoff), broadcast.
    b_model = broadcast(model.filter(col('model_count') >= cutoff)
        .withColumnRenamed('data', col_name)
        .drop('model_count'))
    # Keep the result of the first join in 'id_tmp' so the second join can add its own 'id'.
    df = (df
        .join(b_model, col_name, how='left')
        .withColumnRenamed('id', 'id_tmp'))
    # Helper value for rows that do not need the second join.
    if 'ordinal' in df.columns:
        rand_column = col('ordinal')
    else:
        rand_column = col_of_rand_long()
    # Rows already matched, or null in col_name, get the helper value; rows still
    # to be matched get -1, which is the only key value present in nb_model below.
    df = df.withColumn('join_rand',
        when(col('id_tmp').isNotNull() | col(col_name).isNull(), rand_column)
        .otherwise(lit(-1)))
    # Null out the original value wherever the first join already produced an id.
    df = df.withColumn(col_name,
        when(col('id_tmp').isNotNull(), None)
        .otherwise(col(col_name)))
    # Second join: the infrequent part of the model (model_count < cutoff).
    # NOTE(review): every row kept by the filter ends up with join_rand == -1; the
    # when/otherwise form (instead of a plain literal) presumably keeps the
    # optimizer from treating the join key as a constant - confirm before simplifying.
    nb_model = (model
        .withColumn('join_rand', when(col('model_count') < cutoff, lit(-1)).otherwise(lit(-2)))
        .filter(col('model_count') < cutoff)
        .withColumnRenamed('data', col_name)
        .drop('model_count'))
    # Prefer the id from the second join, fall back to the first join's id.
    df = (df
        .join(nb_model, ['join_rand', col_name], how='left')
        .drop(col_name, 'join_rand')
        .withColumn(col_name, coalesce(col('id'), col('id_tmp')))
        .drop('id', 'id_tmp'))
    return df
def apply_models(df, models, broadcast_model = False, skew_broadcast_pct = 1.0):
    """Replace every categorical column of ``df`` with the ids from its model.

    Args:
        df: DataFrame containing the '_c<i>' categorical columns.
        models: iterable of (column index, model, original_rows, would_broadcast)
            tuples; tuples whose last element is truthy are processed first.
        broadcast_model: when true, every model is joined with a broadcast hint.
        skew_broadcast_pct: percentage of ``original_rows`` used as the cutoff
            passed to skewed_join for models that take the skewed path.

    Returns:
        ``df`` with nulls in the CAT_COLS columns filled with 0.
    """
    ordered = sorted(models, key=itemgetter(3), reverse=True)
    for index, model, original_rows, fits_broadcast in ordered:
        col_name = '_c%d' % index
        if broadcast_model or fits_broadcast:
            # Plain left join, with an explicit broadcast hint only when requested.
            lookup = model.drop('model_count').withColumnRenamed('data', col_name)
            if broadcast_model:
                lookup = broadcast(lookup)
            joined = df.join(lookup, col_name, how='left')
            df = joined.drop(col_name).withColumnRenamed('id', col_name)
            continue
        cutoff = int(original_rows * skew_broadcast_pct/100.0)
        df = skewed_join(df, model, col_name, cutoff)
    return df.fillna(0, ['_c%d' % i for i in CAT_COLS])
def transform_log(df, transform_log = False):
    """Optionally apply log(x + 3) to the INT_COLS columns, then fill nulls with 0.

    Args:
        df: DataFrame containing the '_c<i>' integer columns.
        transform_log: when true, each integer column is replaced by log(value + 3).
    """
    int_names = ['_c%d' % i for i in INT_COLS]
    if transform_log:
        for name in int_names:
            shifted = df[name] + 3
            df = df.withColumn(name, log(shifted))
    return df.fillna(0, int_names)
def would_broadcast(spark, str_path):
    """Return True when the files under ``str_path`` total no more bytes than
    autoBroadcastJoinThreshold * fileCompressionFactor.

    NOTE(review): the threshold is read from a freshly constructed SQLConf, not
    from ``spark.conf`` - confirm whether session-level overrides should apply.
    """
    context = spark.sparkContext
    hadoop_conf = context._jsc.hadoopConfiguration()
    hadoop_fs = context._jvm.org.apache.hadoop.fs
    target = hadoop_fs.Path(str_path)
    # Recursive listing: add up the length of every file below the path.
    files = hadoop_fs.FileSystem.get(hadoop_conf).listFiles(target, True)
    total_bytes = 0
    while files.hasNext():
        total_bytes += files.next().getLen()
    sql_conf = context._jvm.org.apache.spark.sql.internal.SQLConf()
    threshold = sql_conf.autoBroadcastJoinThreshold() * sql_conf.fileCompressionFactor()
    return total_bytes <= threshold
def delete_data_source(spark, path):
    """Recursively delete ``path`` through the Hadoop FileSystem API."""
    context = spark.sparkContext
    hadoop_conf = context._jsc.hadoopConfiguration()
    hadoop_fs = context._jvm.org.apache.hadoop.fs
    target = hadoop_fs.Path(path)
    hadoop_fs.FileSystem.get(hadoop_conf).delete(target, True)
def load_raw(spark, folder, day_range):
    """Read the tab-separated 'day_<n>' files of ``folder`` for every n in ``day_range``.

    The label and INT_COLS columns are read as integers, the CAT_COLS columns
    as strings; all columns are named '_c<index>'.
    """
    fields = [StructField('_c%d' % LABEL_COL, IntegerType())]
    fields.extend(StructField('_c%d' % i, IntegerType()) for i in INT_COLS)
    fields.extend(StructField('_c%d' % i, StringType()) for i in CAT_COLS)
    day_paths = [os.path.join(folder, 'day_%d' % day) for day in day_range]
    reader = spark.read.schema(StructType(fields)).option('sep', '\t')
    return reader.csv(day_paths)
def rand_ordinal(df):
    """Add an 'ordinal' column filled by col_of_rand_long()."""
    ordinal = col_of_rand_long()
    return df.withColumn('ordinal', ordinal)
def day_from_ordinal(df, num_days):
    """Add an integer 'day' column computed as 'ordinal' modulo ``num_days``."""
    day = (col('ordinal') % num_days).cast(IntegerType())
    return df.withColumn('day', day)
def day_from_input_file(df):
    """Add an integer 'day' column taken from the text after the last '_' of
    each row's input file name."""
    suffix = substring_index(input_file_name(), '_', -1)
    return df.withColumn('day', suffix.cast(IntegerType()))
def psudo_sort_by_day_plus(spark, df, num_days):
    """Repartition by 'day' and sort inside each partition by ('day', 'ordinal').

    When 'spark.sql.shuffle.partitions' allows at least one partition per day,
    each day is additionally spread over int(shuffle_parts / num_days) buckets
    derived from 'ordinal'.
    """
    shuffle_parts = int(spark.conf.get('spark.sql.shuffle.partitions'))
    parts_per_day = int(shuffle_parts/num_days)
    if parts_per_day > 0:
        bucket = (col('ordinal') / num_days).cast(LongType()) % parts_per_day
        spread = df.repartition(col('day'), bucket)
    else:
        spread = df.repartition('day')
    return spread.sortWithinPartitions('day', 'ordinal')
def load_combined_model(spark, model_folder):
    """Read 'combined.parquet' from ``model_folder``."""
    return spark.read.parquet(os.path.join(model_folder, 'combined.parquet'))
def save_combined_model(df, model_folder, mode=None):
    """Write ``df`` as 'combined.parquet' inside ``model_folder`` using write mode ``mode``."""
    target = os.path.join(model_folder, 'combined.parquet')
    writer = df.write
    writer.parquet(target, mode=mode)
def delete_combined_model(spark, model_folder):
    """Delete 'combined.parquet' from ``model_folder``."""
    delete_data_source(spark, os.path.join(model_folder, 'combined.parquet'))
def load_low_mem_partial_ids(spark, model_folder):
    """Read 'partial_ids.parquet' from ``model_folder``."""
    return spark.read.parquet(os.path.join(model_folder, 'partial_ids.parquet'))
def save_low_mem_partial_ids(df, model_folder, mode=None):
    """Write ``df`` as 'partial_ids.parquet' inside ``model_folder`` using write mode ``mode``."""
    target = os.path.join(model_folder, 'partial_ids.parquet')
    writer = df.write
    writer.parquet(target, mode=mode)
def delete_low_mem_partial_ids(spark, model_folder):
    """Delete 'partial_ids.parquet' from ``model_folder``."""
    delete_data_source(spark, os.path.join(model_folder, 'partial_ids.parquet'))
def load_column_models(spark, model_folder, count_required):
    """Yield (index, model, stats, would_broadcast) for every column in CAT_COLS.

    ``model`` is read from '<index>.parquet' in ``model_folder``. ``stats`` is
    the collected aggregate row: it always has 'sum' (sum of 'model_count') and
    also has 'size' (row count) when ``count_required`` is true.
    """
    for index in CAT_COLS:
        path = os.path.join(model_folder, '%d.parquet' % index)
        model = spark.read.parquet(path)
        aggregates = [sum('model_count').alias('sum')]
        if count_required:
            aggregates.append(count('*').alias('size'))
        stats = model.agg(*aggregates).collect()[0]
        yield index, model, stats, would_broadcast(spark, path)
def save_column_models(column_models, model_folder, mode=None):
    """Write each (index, model) pair to '<index>.parquet' inside ``model_folder``."""
    for index, model in column_models:
        target = os.path.join(model_folder, '%d.parquet' % index)
        writer = model.write
        writer.parquet(target, mode=mode)
def save_model_size(model_size, path, write_mode):
    """Dump ``model_size`` as indented JSON to ``path``.

    Missing parent directories are created. When ``write_mode`` is
    'errorifexists' and ``path`` already exists, an error is printed and the
    process exits with status 1.
    """
    refuse_overwrite = write_mode == 'errorifexists'
    if refuse_overwrite and os.path.exists(path):
        print('Error: model size file %s exists' % path)
        sys.exit(1)

    parent_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w') as fp:
        json.dump(model_size, fp, indent=4)
# Elapsed seconds per step name, filled in by _timed().
_benchmark = {}


@contextmanager
def _timed(step):
    """Context manager that stores the elapsed time of its body in _benchmark[step].

    Nothing is recorded when the body raises.
    """
    began = time()
    yield
    _benchmark[step] = time() - began
def _parse_args():
    """Parse and validate the command-line arguments.

    Besides the raw options, the returned namespace carries:
      * ``day_range`` - list of day numbers covered by ``--days``;
      * ``days`` - overwritten with the number of days in ``day_range``.

    ``--days`` accepts either 'START-END' (inclusive) or a single day number.
    Invalid values, and a missing ``--output_folder`` in transform mode, are
    reported through ``parser.error`` (usage message, exit status 2).
    """
    parser = ArgumentParser()

    parser.add_argument(
        '--mode',
        required=True,
        choices=['generate_models', 'transform'])

    parser.add_argument('--days', required=True)
    parser.add_argument('--input_folder', required=True)
    parser.add_argument('--output_folder')
    parser.add_argument('--model_size_file')
    parser.add_argument('--model_folder', required=True)
    parser.add_argument(
        '--write_mode',
        choices=['overwrite', 'errorifexists'],
        default='errorifexists')

    parser.add_argument('--frequency_limit')
    parser.add_argument('--no_numeric_log_col', action='store_true')
    # Debug mode keeps the intermediate results that are otherwise deleted.
    parser.add_argument('--low_mem', action='store_true')
    parser.add_argument(
        '--output_ordering',
        choices=['total_random', 'day_random', 'any', 'input'],
        default='total_random')

    parser.add_argument(
        '--output_partitioning',
        choices=['day', 'none'],
        default='none')

    parser.add_argument('--dict_build_shuffle_parallel_per_day', type=int, default=2)
    parser.add_argument('--apply_shuffle_parallel_per_day', type=int, default=25)
    parser.add_argument('--skew_broadcast_pct', type=float, default=1.0)

    parser.add_argument('--debug_mode', action='store_true')

    args = parser.parse_args()

    # Transform mode finishes by writing to --output_folder; reject a missing
    # value up front instead of failing after the whole job has run.
    if args.mode == 'transform' and not args.output_folder:
        parser.error('--output_folder is required when --mode is transform')

    # Accept 'START-END' as well as a single day number.
    start, separator, end = args.days.partition('-')
    if not separator:
        end = start
    try:
        first_day, last_day = int(start), int(end)
    except ValueError:
        parser.error('--days must be START-END or a single day number, got %r' % args.days)
    if last_day < first_day:
        parser.error('--days range is empty: %r' % args.days)

    args.day_range = list(range(first_day, last_day + 1))
    args.days = len(args.day_range)

    return args
def _generate_models(spark, df, args):
    """Build the categorical dictionaries from ``df`` and save one model per column."""
    spark.conf.set(
        'spark.sql.shuffle.partitions',
        args.days * args.dict_build_shuffle_parallel_per_day)
    with _timed('generate models'):
        col_counts = get_column_counts_with_frequency_limit(df, args.frequency_limit)
        if args.low_mem:
            # The partial ids are written out and read back before the final
            # ids are derived from them.
            save_low_mem_partial_ids(
                assign_low_mem_partial_ids(col_counts),
                args.model_folder,
                args.write_mode)
            save_combined_model(
                assign_low_mem_final_ids(load_low_mem_partial_ids(spark, args.model_folder)),
                args.model_folder,
                args.write_mode)
            if not args.debug_mode:
                delete_low_mem_partial_ids(spark, args.model_folder)
        else:
            save_combined_model(
                assign_id_with_window(col_counts),
                args.model_folder,
                args.write_mode)
        save_column_models(
            get_column_models(load_combined_model(spark, args.model_folder)),
            args.model_folder,
            args.write_mode)
        if not args.debug_mode:
            delete_combined_model(spark, args.model_folder)


def _transform(spark, df, args):
    """Encode ``df`` with the saved column models and write it to ``args.output_folder``."""
    # Apply the transform-specific parallelism; psudo_sort_by_day_plus reads this
    # setting, and without it --apply_shuffle_parallel_per_day had no effect.
    spark.conf.set(
        'spark.sql.shuffle.partitions',
        args.days * args.apply_shuffle_parallel_per_day)
    with _timed('transform'):
        partition_by_day = args.output_partitioning == 'day'

        # Add the helper columns ('ordinal' and/or 'day') the chosen ordering needs.
        if args.output_ordering == 'total_random':
            df = rand_ordinal(df)
            if partition_by_day:
                df = day_from_ordinal(df, args.days)
        elif args.output_ordering == 'day_random':
            df = rand_ordinal(df)
            df = day_from_input_file(df)
        elif args.output_ordering == 'input':
            df = df.withColumn('ordinal', monotonically_increasing_id())
            if partition_by_day:
                df = day_from_input_file(df)
        else:
            # 'any' ordering
            if partition_by_day:
                df = day_from_input_file(df)

        models = list(load_column_models(spark, args.model_folder, bool(args.model_size_file)))
        if args.model_size_file:
            save_model_size(
                OrderedDict(('_c%d' % i, agg.size) for i, _, agg, _ in models),
                args.model_size_file,
                args.write_mode)
        # apply_models expects the summed model_count, not the whole aggregate row.
        models = [(i, model, agg.sum, flag) for i, model, agg, flag in models]

        df = apply_models(
            df,
            models,
            not args.low_mem,
            args.skew_broadcast_pct)
        df = transform_log(df, not args.no_numeric_log_col)

        partitionBy = 'day' if partition_by_day else None

        # Order the rows as requested and drop the helper columns again.
        if args.output_ordering == 'total_random':
            if partition_by_day:
                df = psudo_sort_by_day_plus(spark, df, args.days)
            else:
                df = df.repartition('ordinal').sortWithinPartitions('ordinal')
            df = df.drop('ordinal')
        elif args.output_ordering == 'day_random':
            df = psudo_sort_by_day_plus(spark, df, args.days)
            df = df.drop('ordinal')
            if not partition_by_day:
                df = df.drop('day')
        elif args.output_ordering == 'input':
            if args.low_mem:
                df = df.orderBy('ordinal')
            else:
                df = df.sortWithinPartitions('ordinal')
            df = df.drop('ordinal')
            if not partition_by_day:
                df = df.drop('day')
        # 'any' ordering: nothing to sort or drop.

        df.write.parquet(
            args.output_folder,
            mode=args.write_mode,
            partitionBy=partitionBy)


def _main():
    """Command-line entry point: run the selected mode and print the timings.

    Reads the raw day files, then either builds the per-column models
    ('generate_models') or applies them and writes the encoded data
    ('transform'). The recorded step timings are printed at the end.
    """
    args = _parse_args()
    spark = SparkSession.builder.getOrCreate()

    df = load_raw(spark, args.input_folder, args.day_range)

    if args.mode == 'generate_models':
        _generate_models(spark, df, args)

    if args.mode == 'transform':
        _transform(spark, df, args)

    print('=' * 100)
    print(_benchmark)
# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    _main()