From 21e8bad1d92f264ed09c04ed92325ef71c5fa204 Mon Sep 17 00:00:00 2001 From: amackillop Date: Thu, 2 Jul 2020 17:33:18 -0400 Subject: [PATCH] python stuff --- .gitignore | 3 +- input.py | 5 + interactive_predict.py | 28 ++-- preprocess_python.sh | 61 ++++++++ python_extractor/.gitignore | 5 + python_extractor/__init__.py | 0 python_extractor/extract.py | 123 ++++++++++++++++ python_extractor/extractor.py | 260 ++++++++++++++++++++++++++++++++++ train.sh | 7 +- 9 files changed, 480 insertions(+), 12 deletions(-) create mode 100644 input.py create mode 100755 preprocess_python.sh create mode 100644 python_extractor/.gitignore create mode 100644 python_extractor/__init__.py create mode 100644 python_extractor/extract.py create mode 100644 python_extractor/extractor.py diff --git a/.gitignore b/.gitignore index a917dee..f67da6a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ **/data/** **/.idea/** *.tar.gz -**/log.txt \ No newline at end of file +**/log.txt +__pycache__ \ No newline at end of file diff --git a/input.py b/input.py new file mode 100644 index 0000000..d0b10fc --- /dev/null +++ b/input.py @@ -0,0 +1,5 @@ +def fact(n): + if n == 0: + return 1 + else: + return n * fact(n - 1) diff --git a/interactive_predict.py b/interactive_predict.py index 78aac6c..e22a4b1 100644 --- a/interactive_predict.py +++ b/interactive_predict.py @@ -1,12 +1,11 @@ import traceback from common import common -from extractor import Extractor +from python_extractor.extractor import Extractor SHOW_TOP_CONTEXTS = 10 MAX_PATH_LENGTH = 8 MAX_PATH_WIDTH = 2 -JAR_PATH = 'JavaExtractor/JPredict/target/JavaExtractor-0.0.1-SNAPSHOT.jar' class InteractivePredictor: @@ -16,9 +15,7 @@ def __init__(self, config, model): model.predict([]) self.model = model self.config = config - self.path_extractor = Extractor(config, - jar_path=JAR_PATH, - max_path_length=MAX_PATH_LENGTH, + self.path_extractor = Extractor(max_path_length=MAX_PATH_LENGTH, max_path_width=MAX_PATH_WIDTH) def read_file(self, 
input_filename): @@ -26,7 +23,8 @@ def read_file(self, input_filename): return file.readlines() def predict(self): - input_filename = 'Input.java' + # input_filename = 'Input.java' + input_filename = 'input.py' print('Starting interactive prediction...') while True: print( @@ -36,11 +34,20 @@ def predict(self): print('Exiting...') return try: - predict_lines, hash_to_string_dict = self.path_extractor.extract_paths(input_filename) + predict_lines = list(path.strip() for path in self.path_extractor.extract_paths(input_filename)) + contexts = predict_lines[0].split() + space_padding = ' ' * (self.config.MAX_CONTEXTS - len(contexts) + 1) + predict_lines[0] = ' '.join(contexts) + space_padding + print(predict_lines) except ValueError as e: print(e) continue - raw_prediction_results = self.model.predict(predict_lines) + hash_to_string_dict = UnitDict() + try: + raw_prediction_results = self.model.predict(predict_lines) + except Exception as exc: + print(exc) + continue method_prediction_results = common.parse_prediction_results( raw_prediction_results, hash_to_string_dict, self.model.vocabs.target_vocab.special_words, topk=SHOW_TOP_CONTEXTS) @@ -55,3 +62,8 @@ def predict(self): if self.config.EXPORT_CODE_VECTORS: print('Code vector:') print(' '.join(map(str, raw_prediction.code_vector))) + +class UnitDict(dict): + + def __getitem__(self, key): + return key diff --git a/preprocess_python.sh b/preprocess_python.sh new file mode 100755 index 0000000..936bf52 --- /dev/null +++ b/preprocess_python.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +########################################################### +# Change the following values to preprocess a new dataset. +# TRAIN_DIR, VAL_DIR and TEST_DIR should be paths to +# directories containing sub-directories with .py files +# (each of {TRAIN_DIR, VAL_DIR and TEST_DIR} should have sub-dirs, +# and data will be extracted from .py files found in those sub-dirs). +# DATASET_NAME is just a name for the currently extracted +# dataset. 
+# MAX_CONTEXTS is the number of contexts to keep for each +# method (by default 200). +# WORD_VOCAB_SIZE, PATH_VOCAB_SIZE, TARGET_VOCAB_SIZE - +# - the number of words, paths and target words to keep +# in the vocabulary (the top occurring words and paths will be kept). +# The default values are reasonable for a Tesla K80 GPU +# and newer (12 GB of board memory). +# NUM_THREADS - the number of parallel threads to use. It is +# recommended to use a multi-core machine for the preprocessing +# step and set this value to the number of cores. +# PYTHON - python3 interpreter alias. +# TRAIN_DIR=../data/java-small/test +# VAL_DIR=../data/java-small/test +# TEST_DIR=../data/java-small/test +DATASET_NAME=python_20k +MAX_CONTEXTS=200 +WORD_VOCAB_SIZE=1301136 +PATH_VOCAB_SIZE=911417 +TARGET_VOCAB_SIZE=261245 +NUM_THREADS=64 +PYTHON=python +########################################################### +REPO_DIR=repos +CONTEXTS_DIR=output +DATA_DIR=data + +TRAIN_DATA_FILE=${CONTEXTS_DIR}/train/path_contexts.csv +VAL_DATA_FILE=${CONTEXTS_DIR}/val/path_contexts.csv +TEST_DATA_FILE=${CONTEXTS_DIR}/test/path_contexts.csv + +mkdir -p ${DATA_DIR}/${DATASET_NAME} + +echo "Extracting paths..." 
+${PYTHON} python_extractor/extract.py --in_dir ${REPO_DIR} --out_dir ${CONTEXTS_DIR} --max_path_length 8 --max_path_width 2 --max_workers ${NUM_THREADS} + +TARGET_HISTOGRAM_FILE=../data/${DATASET_NAME}/${DATASET_NAME}.histo.tgt.c2v +ORIGIN_HISTOGRAM_FILE=../data/${DATASET_NAME}/${DATASET_NAME}.histo.ori.c2v +PATH_HISTOGRAM_FILE=../data/${DATASET_NAME}/${DATASET_NAME}.histo.path.c2v + +echo "Creating histograms from the training data" +cat ${TRAIN_DATA_FILE} | cut -d' ' -f1 | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${TARGET_HISTOGRAM_FILE} +cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${ORIGIN_HISTOGRAM_FILE} +cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${PATH_HISTOGRAM_FILE} + +${PYTHON} preprocess.py --train_data ${TRAIN_DATA_FILE} --test_data ${TEST_DATA_FILE} --val_data ${VAL_DATA_FILE} \ + --max_contexts ${MAX_CONTEXTS} --word_vocab_size ${WORD_VOCAB_SIZE} --path_vocab_size ${PATH_VOCAB_SIZE} \ + --target_vocab_size ${TARGET_VOCAB_SIZE} --word_histogram ${ORIGIN_HISTOGRAM_FILE} \ + --path_histogram ${PATH_HISTOGRAM_FILE} --target_histogram ${TARGET_HISTOGRAM_FILE} --output_name ../data/${DATASET_NAME}/${DATASET_NAME} + +# If all went well, the raw data files can be deleted, because preprocess.py creates new files +# with truncated and padded number of paths for each example. 
+rm ${TARGET_HISTOGRAM_FILE} ${ORIGIN_HISTOGRAM_FILE} ${PATH_HISTOGRAM_FILE} diff --git a/python_extractor/.gitignore b/python_extractor/.gitignore new file mode 100644 index 0000000..4a87948 --- /dev/null +++ b/python_extractor/.gitignore @@ -0,0 +1,5 @@ +.* +__pycache__/ +main.json +*.c2v +!.git* \ No newline at end of file diff --git a/python_extractor/__init__.py b/python_extractor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python_extractor/extract.py b/python_extractor/extract.py new file mode 100644 index 0000000..365a760 --- /dev/null +++ b/python_extractor/extract.py @@ -0,0 +1,123 @@ +from __future__ import annotations + + +from argparse import ArgumentParser +from concurrent.futures import ThreadPoolExecutor +from pebble import ProcessPool +import concurrent.futures as cf +from functools import partial +from pathlib import PosixPath +import os +import math +import itertools as it +from glob import iglob, glob +import subprocess as sp +from tqdm import tqdm +from typing import Iterator, List, Iterable, Tuple, Union + +import re + +from extractor import Extractor + + +Path = Union[str, PosixPath] + + +def process(fname: str, max_length: int, max_wdith: int) -> List[str]: + extractor = Extractor(max_path_length=8, max_path_width=2) + try: + paths = extractor.extract_paths(fname) + except (ValueError, SyntaxError, RecursionError): + return list() + return list(paths) + + +def write_lines(fname: str, lines: Iterable[str]) -> None: + with open(fname, "a", encoding="ISO-8859-1") as stream: + stream.writelines(map(mask_method_name, lines)) + + +def mask_method_name(line: str) -> str: + method_name, _, _ = line.partition(" ") + pattern = re.compile(re.escape(f" {method_name},")) + return pattern.sub(" METHOD_NAME,", line) + + +def to_str_path(list_path: List[str]) -> str: + return f"{list_path[0]},{'|'.join(list_path[1:-1])},{list_path[-1]}" + + +def make_posix_path(path: Path) -> PosixPath: + return PosixPath(path) if isinstance(path, 
str) else path + + +def concatenate_path_conext_files(mined_dir_path: Path) -> None: + mined_dir_path = make_posix_path(mined_dir_path) + dtq = tqdm(["train", "test", "val"], desc="concatenating ast path conext files") + for _dir in dtq: + file_dir = str(mined_dir_path / f"{_dir}") + concate_sh = f"cat {file_dir}/*.c2v > {file_dir}/path_contexts.csv" + sp.run(concate_sh, shell=True, check=True) + + for f in iglob(str(mined_dir_path / "*/*.c2v")): + os.remove(f) + + print("Done concatenating all path_contexts from AST miner to a single file") + + +def source_files(data_dir: str): + for fname in iglob(f"{data_dir}/*/**/[!setup]*.py", recursive=True): + if os.path.isfile(fname) and not fname.startswith("test"): + yield fname + + +def chunker(iterable, n, fillvalue=None): + "Collect data into fixed-length chunks or blocks" + # chunker('ABCDEFG', 3, 'x') --> ABC DEF Gxx" + args = [iter(iterable)] * n + return it.zip_longest(*args, fillvalue=fillvalue) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("-maxlen", "--max_path_length", dest="max_path_length", required=False, default=8) + parser.add_argument("-maxwidth", "--max_path_width", dest="max_path_width", required=False, default=2) + parser.add_argument("-workers", "--max_workers", dest="max_workers", required=False, default=None) + parser.add_argument("-in_dir", "--in_dir", dest="in_dir", required=True) + parser.add_argument("-out_dir", "--out_dir", dest="out_dir", required=True) + # parser.add_argument("-file", "--file", dest="file", required=False) + args = parser.parse_args() + + TIMEOUT = 60 * 10 + MAX_WORKERS = int(args.max_workers) + MAX_LENGTH = args.max_path_length + MAX_WIDTH = args.max_path_width + REPOS = args.in_dir + OUTPUT = args.out_dir + + writes = list() + futures = list() + with ProcessPool(max_workers=MAX_WORKERS) as pool, ThreadPoolExecutor( + max_workers=1 + ) as writer: + futures = { + pool.schedule(process, args=[fname, MAX_LENGTH, MAX_WIDTH], timeout=TIMEOUT): 
fname + for fname in source_files(REPOS) + } + + for future in tqdm(cf.as_completed(futures), total=len(futures)): + fname = futures[future] + splitted = fname.split("/") + project = splitted[2] + bin_ = splitted[1] + c2v_file = f"{OUTPUT}/{bin_}/{project}.c2v" + try: + paths = future.result() + except cf.TimeoutError: + continue + if paths: + writes.append(writer.submit(partial(write_lines, c2v_file), paths)) + + cf.wait(writes) + + concatenate_path_conext_files(mined_dir_path=OUTPUT) diff --git a/python_extractor/extractor.py b/python_extractor/extractor.py new file mode 100644 index 0000000..e7a83ba --- /dev/null +++ b/python_extractor/extractor.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import ast +from ast import NodeVisitor, increment_lineno +from contextlib import contextmanager +import dataclasses as dc +import collections +import itertools as it +import re +from typing import ( + Iterable, + List, + Iterator, + Union, + Dict, + List, + Optional, + TypeVar, +) + +A = TypeVar("A") + + +class Extractor(NodeVisitor): + def __init__(self, max_path_length, max_path_width): + self._stack: List[JsonTree] = list() + self._func_name: str = str() + self.tree: List[JsonTree] = list() + self._json_tree: List[dict] = list() + self.paths: List[List[JsonTree]] = list() + self.paths_map: Dict[str, List[List[str]]] = dict() + self.MAX_DEPTH = max_path_length + self.MAX_WIDTH = max_path_width + self._build_json = False + self.replace_pattern = re.compile("[0-9_,]") + + def add_path(self, path: List[JsonTree]): + """ + The hyperparameters for filtering out some paths could be applied here. 
+ + """ + + def transform(tree: JsonTree) -> str: + return tree.node_type if isinstance(tree, JsonNode) else tree.value + + for prev_path in self.paths: + merged = self.merge(prev_path, path) + if merged: + # For some reason we must cast to list here or the last token goes missing + self.paths_map[self._func_name].append(list(map(transform, merged))) + + self.paths.append(list(path)) + + def merge( + self, left_path: List[JsonTree], right_path: List[JsonTree] + ) -> Optional[Iterator[JsonTree]]: + """ + V + / \ + + Once we have vertex, the index of left and right canot be more than W apart? + + """ + if self.fails_depth_check(left_path, right_path): + return None + + lefts, rights = iter(left_path), iter(right_path) + for left, right in zip(lefts, rights): + if id(left) == id(right): + vertex = left + else: + break + + if self.fails_width_check(vertex.children, left, right): + return None + + merged = it.chain(reversed(list(lefts)), [left, vertex, right], rights) + return merged + + def fails_depth_check(self, left_path: Iterable, right_path: Iterable) -> bool: + length = len(set(map(id, left_path)) ^ set(map(id, right_path))) + 1 + return length > self.MAX_DEPTH + + def fails_width_check(self, children: List[A], left: A, right: A) -> bool: + left_index = children.index(left) + right_index = children.index(right) + + return right_index - left_index > self.MAX_WIDTH + + def _parse(self, fname: str) -> List[dict]: + """I only care about function/method definitions for now""" + + with open(fname, "r", encoding="ISO-8859-1") as stream: + tree = ast.parse(stream.read()) + + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + self._func_name = self.clean(node.name) + self.paths_map[self._func_name] = [] + self.visit(node) + self.paths = list() + + return self._json_tree + + def clean(self, token: str) -> str: + token = self.to_snake_case(token) + return self.replace_pattern.sub("|", token).strip("|") + + @staticmethod + def 
to_snake_case(token: str) -> str: + splitted = re.sub( + "([A-Z][a-z]+)", r" \1", re.sub("([A-Z]+)", r" \1", token) + ).split() + return "_".join(word.strip('_').lower() for word in splitted) + + def extract_paths(self, fname: str) -> Iterator[str]: + # Can perhaps take the transformation function as a parameter for more flexibility + def transform(path_contexts: Iterable[Iterable[str]]) -> str: + def to_str_path(path_context: Iterable[str]) -> str: + context = list(path_context) + return f"{self.clean(context[0])},{'|'.join(context[1:-1])},{self.clean(context[-1])}" + + str_paths = map(to_str_path, path_contexts) + return f"{' '.join(str_paths)}".encode("unicode_escape").decode( + "ISO-8859-1" + ) + + self._parse(fname) + + paths = ( + f"{func_name} {transform(contexts)}\n" + for func_name, contexts in self.paths_map.items() + ) + self.paths_map = dict() + return paths + + def to_json(self, fname: str) -> List[dict]: + """Parse the syntax tree and output to json + + :param fname: The filename. + :type fname: str + :return: A json representation of the syntax tree. 
+ :rtype: dict + """ + self._build_json = True + json_tree = self._parse(fname) + self._json_tree = list() + self._build_json = False + return json_tree + + def visit(self, node: ast.AST): + if isinstance( + node, (ast.boolop, ast.cmpop, ast.unaryop, ast.operator, ast.expr_context) + ): + return + if isinstance(node, (ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.AugAssign)): + self.visit_Op(node) + elif isinstance(node, (ast.Break, ast.Continue)): + self.visit_Break_Continue(node) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + self.visit_Function(node) + else: + super().visit(node) + + def generic_visit(self, node: ast.AST): + """Called if no explicit visitor function exists for a node.""" + json_node = JsonNode(type(node).__name__) + with self.push_to_stack(json_node): + super().generic_visit(node) + + def visit_Function(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]): + method_name = JsonLeaf("MethodName", node.name) + if ast.get_docstring(node): + node.body = node.body[1:] + json_node = JsonNode(type(node).__name__) + with self.push_to_stack(json_node): + with self.push_to_stack(method_name): + self.add_path(self._stack) + super().generic_visit(node) + self.tree.append(json_node) + if self._build_json: + self._json_tree.append(json_node.to_dict()) + + def visit_arg(self, node: ast.arg): + arg_name = JsonLeaf("ArgName", node.arg) + json_node = JsonNode(type(node).__name__) + with self.push_to_stack(json_node): + with self.push_to_stack(arg_name): + self.add_path(self._stack) + super().generic_visit(node) + + def visit_Op(self, node: Union[ast.UnaryOp, ast.BoolOp, ast.BinOp, ast.AugAssign]): + operator = type(node.op).__name__ + node_type = "_".join([type(node).__name__, operator]) + json_node = JsonNode(node_type) + with self.push_to_stack(json_node): + super().generic_visit(node) + + def visit_Compare(self, node: ast.Compare): + json_node = JsonNode("Compare_" + type(node.ops[0]).__name__) + with self.push_to_stack(json_node): + 
super().generic_visit(node) + + def visit_Name(self, node: ast.Name): + json_node = JsonLeaf(type(node).__name__, node.id) + with self.push_to_stack(json_node): + self.add_path(self._stack) + + def visit_Constant(self, node: ast.Constant): + json_node = JsonLeaf( + type(node).__name__, + str(node.value).encode("unicode_escape").decode().replace(" ", ""), + ) + with self.push_to_stack(json_node): + self.add_path(self._stack) + + def visit_Break_Continue(self, node: Union[ast.Break, ast.Continue]): + json_node = JsonLeaf(type(node).__name__, type(node).__name__) + with self.push_to_stack(json_node): + self.add_path(self._stack) + + def add_to_tree(self, json_tree: JsonTree): + current_node = self._stack.pop() + if isinstance(current_node, JsonNode): + current_node.children.append(json_tree) + else: + raise RuntimeError("JsonLeaf node left on stack!") + self._stack.append(current_node) + + @contextmanager + def push_to_stack(self, json_node: JsonTree): + if self._stack: + self.add_to_tree(json_node) + self._stack.append(json_node) + try: + yield + finally: + self._stack.pop() + + +class _JsonTree: + """""" + + def to_dict(self): + return dc.asdict(self) + + +@dc.dataclass(frozen=True) +class JsonNode(_JsonTree): + node_type: str + children: List[_JsonTree] = dc.field(default_factory=list) + + +@dc.dataclass(frozen=True) +class JsonLeaf(_JsonTree): + node_type: str + value: str + + +JsonTree = Union[JsonNode, JsonLeaf] diff --git a/train.sh b/train.sh index e74fad6..c574de0 100644 --- a/train.sh +++ b/train.sh @@ -6,13 +6,14 @@ # test_data: by default, points to the validation set, since this is the set that # will be evaluated after each training iteration. If you wish to test # on the final (held-out) test set, change 'val' to 'test'. 
-type=java14m -dataset_name=java14m +type=python_20k +dataset_name=python_20k data_dir=data/${dataset_name} data=${data_dir}/${dataset_name} test_data=${data_dir}/${dataset_name}.val.c2v model_dir=models/${type} +PYTHON=python mkdir -p ${model_dir} set -e -python3 -u code2vec.py --data ${data} --test ${test_data} --save ${model_dir}/saved_model +${PYTHON} -u code2vec.py --data ${data} --test ${test_data} --save ${model_dir}/saved_model