From 4073eb07d0c564a7edf56fe3af428e2ab44eb19f Mon Sep 17 00:00:00 2001
From: Raymond Li
Date: Tue, 1 Nov 2022 11:55:01 -0400
Subject: [PATCH] add filtering code

---
 data_analysis/filtering/filtering.py   | 213 +++++++++++++++++++++++++
 data_analysis/util/dataset_sharding.py |  29 ++++
 2 files changed, 242 insertions(+)
 create mode 100644 data_analysis/filtering/filtering.py
 create mode 100644 data_analysis/util/dataset_sharding.py

diff --git a/data_analysis/filtering/filtering.py b/data_analysis/filtering/filtering.py
new file mode 100644
index 0000000..6b5a7a2
--- /dev/null
+++ b/data_analysis/filtering/filtering.py
@@ -0,0 +1,213 @@
+import json
+import os
+from dataclasses import dataclass, field
+from glob import glob
+from typing import Tuple
+
+from datasets import load_dataset
+
+from util.dataset_sharding import shard_dataset
+
+
+@dataclass
+class FilterContentArguments:
+    # Remove files with an average line length > 100.
+    max_avg_line_length: int = field(
+        default=100,
+        metadata={
+            "help": "Maximum average line length. Files with a larger average line length will be discarded."
+        },
+    )
+    # Remove files with a maximum line length > 1000.
+    max_line_length: int = field(
+        default=1000,
+        metadata={
+            "help": "Maximum line length. Files with a larger maximum line length will be discarded."
+        },
+    )
+    # Remove files with a fraction of alphanumeric characters < 0.25.
+    min_alphanum_fraction: float = field(
+        default=0.25, metadata={"help": "Minimum fraction of alphanumeric characters."}
+    )
+    check_auto_generated: bool = field(
+        default=True, metadata={"help": "Whether to check for auto-generated files."}
+    )
+
+
+class FilterContentMeta:
+    """Accumulates per-language statistics and filter counts for one batch."""
+
+    def __init__(self) -> None:
+        self.per_language_stats = dict()
+        self.per_language_filter_reasons = dict()
+
+    def to_dict(self):
+        return {
+            "per_language_stats": self.per_language_stats,
+            "per_language_filter_reasons": self.per_language_filter_reasons,
+        }
+
+    def update_language_stats(self, lang: str, size: int):
+        if lang not in self.per_language_stats:
+            self.per_language_stats[lang] = {"num_files": 0, "total_size": 0}
+        self.per_language_stats[lang]["num_files"] += 1
+        self.per_language_stats[lang]["total_size"] += size
+
+    def update_language_filter_reason(self, lang: str, filter_reason: str):
+        if lang not in self.per_language_filter_reasons:
+            self.per_language_filter_reasons[lang] = {}
+        if filter_reason not in self.per_language_filter_reasons[lang]:
+            self.per_language_filter_reasons[lang][filter_reason] = 0
+        self.per_language_filter_reasons[lang][filter_reason] += 1
+
+
+# Inspired by https://github.com/huggingface/transformers/blob/master/examples/research_projects/codeparrot/scripts/preprocessing.py
+def is_autogenerated(file_content, scan_width=5):
+    """Check if a file is auto-generated by looking for keywords in its first few lines."""
+    keywords = ["auto-generated", "autogenerated", "automatically generated"]
+    lines = file_content.splitlines()
+    for _, line in zip(range(scan_width), lines):
+        for keyword in keywords:
+            if keyword in line.lower():
+                return True
+    return False
+
+
+LANGUAGE_COL = "lang"
+SIZE_COL = "size"
+SHARD_SIZE = 1000 << 20  # ~1GB
+NUM_PROC = 64
+
+
+def add_dict(dict1: dict, dict2: dict) -> None:
+    """
+    Recursively add the values of dict2 to dict1. All values must be int, float,
+    or dicts whose values also satisfy this condition.
+    Modifies dict1 in place and returns None.
+    """
+    for key, value in dict2.items():
+        if isinstance(value, (int, float)):
+            if key not in dict1:
+                dict1[key] = 0
+            dict1[key] += value
+        elif isinstance(value, dict):
+            if key not in dict1:
+                dict1[key] = {}
+            assert isinstance(dict1[key], dict)
+            add_dict(dict1[key], value)
+        else:
+            raise ValueError(f"Invalid type for key/value {key}: {value}")
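+# A minimal usage sketch for add_dict (illustrative values, not from a real
+# run): nested counters are summed in place, so per-batch meta dicts can be
+# merged into one aggregate.
+#
+#   totals = {"python": {"num_files": 1, "total_size": 100}}
+#   add_dict(totals, {"python": {"num_files": 2, "total_size": 50}})
+#   assert totals == {"python": {"num_files": 3, "total_size": 150}}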
+
+
+def aggregate_meta(tmp_meta_dir: str):
+    """Merge the per-batch meta files written by filter_batch into one dict."""
+    res = {}
+    for file in glob(f"{tmp_meta_dir}/*-meta.json"):
+        with open(file, "r") as f:
+            meta = json.load(f)
+        add_dict(res, meta)
+    return res
+
+
+def filter_sample(
+    sample, filter_content_args: FilterContentArguments
+) -> Tuple[bool, str]:
+    """
+    Returns a tuple (should_include, filter_reason):
+    - `should_include`: whether the sample should be kept in the dataset
+    - `filter_reason`: if the sample is to be discarded, the reason why it was filtered out
+    """
+    # Average line length
+    if (
+        filter_content_args.max_avg_line_length > 0
+        and sample["avg_line_length"] > filter_content_args.max_avg_line_length
+    ):
+        return False, "avg_line_length"
+
+    # Max line length
+    if (
+        filter_content_args.max_line_length > 0
+        and sample["max_line_length"] > filter_content_args.max_line_length
+    ):
+        return False, "max_line_length"
+
+    # Alphanumeric characters
+    if (
+        filter_content_args.min_alphanum_fraction > 0
+        and sample["alphanum_fraction"] < filter_content_args.min_alphanum_fraction
+    ):
+        return False, "alphanum_fraction"
+
+    # Auto-generated files
+    if filter_content_args.check_auto_generated and is_autogenerated(sample["content"]):
+        return False, "autogenerated"
+
+    return True, ""
+
+
+def filter_batch(
+    batch: dict, idx, filter_content_args: FilterContentArguments, tmp_meta_dir: str
+):
+    """Keep only the samples that pass filter_sample; record stats in a meta file."""
+    meta = FilterContentMeta()
+    features = batch.keys()
+    res = {k: [] for k in features}
+    for sample in zip(*[batch[k] for k in features]):
+        sample = {k: v for k, v in zip(features, sample)}
+        should_include, filter_reason = filter_sample(sample, filter_content_args)
+        if not should_include:
+            meta.update_language_filter_reason(sample[LANGUAGE_COL], filter_reason)
+        else:
+            meta.update_language_stats(sample[LANGUAGE_COL], sample[SIZE_COL])
+            # Add to output
+            for k in features:
+                res[k].append(sample[k])
+
+    # Record meta
+    with open(os.path.join(tmp_meta_dir, f"{idx[0]}-{idx[-1]}-meta.json"), "w") as f:
+        json.dump(meta.to_dict(), f)
+    return res
+
+
+def filter_dataset(
+    dataset, filter_content_args: FilterContentArguments, tmp_meta_dir: str
+):
+    filtered = dataset.map(
+        filter_batch,
+        batched=True,
+        with_indices=True,
+        num_proc=NUM_PROC,
+        fn_kwargs={
+            "filter_content_args": filter_content_args,
+            "tmp_meta_dir": tmp_meta_dir,
+        },
+        load_from_cache_file=False,
+    )
+    return filtered
+
+
+def main():
+    dataset_name = "/data/stack_python/data"
+    output_dir = "/data/filtering/test_dataset_stack"
+    tmp_meta_dir = f"{output_dir}/tmp/meta"
+    data_dir = f"{output_dir}/data"
+
+    ds = load_dataset(
+        dataset_name, split="train", use_auth_token=True, chunksize=40 << 20
+    )
+    filter_content_args = FilterContentArguments()
+
+    os.makedirs(tmp_meta_dir)
+    os.makedirs(data_dir)
+
+    filtered = filter_dataset(ds, filter_content_args, tmp_meta_dir)
+    print(f"Number of files after filtering: {len(filtered)}")
+    # Dump meta
+    meta = aggregate_meta(tmp_meta_dir)
+    print(meta)
+    with open(os.path.join(output_dir, "meta.json"), "w") as f:
+        json.dump(meta, f)
+
+    # Save shards of the filtered dataset
+    shard_dataset(filtered, SHARD_SIZE, data_dir, num_proc=NUM_PROC)
+
+
+if __name__ == "__main__":
+    main()
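A quick smoke test of filter_sample against the default thresholds (the sample
values below are made up for illustration; any dict with these keys works):

    args = FilterContentArguments()
    sample = {
        "avg_line_length": 40.0,
        "max_line_length": 120,
        "alphanum_fraction": 0.8,
        "content": "print('hello')",
    }
    assert filter_sample(sample, args) == (True, "")
    assert filter_sample({**sample, "max_line_length": 5000}, args) == (
        False,
        "max_line_length",
    )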
diff --git a/data_analysis/util/dataset_sharding.py b/data_analysis/util/dataset_sharding.py
new file mode 100644
index 0000000..4414b6e
--- /dev/null
+++ b/data_analysis/util/dataset_sharding.py
@@ -0,0 +1,29 @@
+import time
+from multiprocessing import Pool
+
+from tqdm import tqdm
+
+
+def save_shard(shard_tuple):
+    """Save a (filename, shard) tuple to disk as parquet."""
+    filename, shard = shard_tuple
+    # Use shard.to_json(filename) instead to save as a JSON file.
+    shard.to_parquet(filename)
+
+
+def shard_dataset(ds, shard_size, output_dir, num_proc):
+    # Estimate the dataset size in bytes, accounting for a possible indices
+    # mapping (e.g. after select/shuffle), where only part of the underlying
+    # table is live.
+    if ds._indices is not None:
+        dataset_nbytes = ds.data.nbytes * len(ds._indices) / len(ds.data)
+    else:
+        dataset_nbytes = ds.data.nbytes
+    num_shards = int(dataset_nbytes / shard_size) + 1
+    print(f"Number of shards: {num_shards}")
+
+    print("Sharding the dataset")
+    t_start = time.time()
+    shards = (
+        ds.shard(num_shards=num_shards, index=i, contiguous=True)
+        for i in range(num_shards)
+    )
+    # Use f"{output_dir}/train-{index:05d}-of-{num_shards:05d}.json" instead
+    # for JSON files.
+    filenames = (
+        f"{output_dir}/train-{index:05d}-of-{num_shards:05d}.parquet"
+        for index in range(num_shards)
+    )
+
+    with Pool(num_proc) as p:
+        list(
+            tqdm(
+                p.imap_unordered(save_shard, zip(filenames, shards), chunksize=4),
+                total=num_shards,
+            )
+        )
+    print(f"Time to save dataset: {time.time() - t_start:.2f}s")
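The saved parquet shards can be reloaded later with the datasets parquet
builder; a sketch, assuming the output paths used in main above:

    from datasets import load_dataset

    ds = load_dataset(
        "parquet",
        data_files="/data/filtering/test_dataset_stack/data/train-*.parquet",
        split="train",
    )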