import os
from random import randint

from tqdm import tqdm

from utils import ORIGINAL_DATA_BASE, read_file
from my_utils import multi_thread


def create(pos, neg, tgt):
    """Write one randomly sampled negative pair per positive line.

    For every line read from ``pos`` the first tab-separated field is kept
    and paired with a field taken from a randomly chosen line of ``neg``.
    Each pair is written to ``tgt`` as ``first<TAB>sampled<NEWLINE>``.

    Args:
        pos: Path passed to ``read_file`` to obtain the positive lines.
        neg: Path passed to ``read_file`` to obtain the lines sampled from.
        tgt: Path of the UTF-8 text file to (over)write.

    Raises:
        ValueError: If there are positive lines to pair but ``neg`` yields
            no lines to sample from.
    """
    pos_sents = read_file(pos)
    neg_sents = read_file(neg)
    neg_length = len(neg_sents)

    # Fail before the target is opened, so an unusable negative source does
    # not leave a truncated output file behind.
    if pos_sents and neg_length == 0:
        raise ValueError(
            f"negative source {neg!r} is empty; cannot sample from it"
        )

    with open(tgt, "w", encoding="utf-8") as fout:
        for sent in tqdm(pos_sents):
            # Strip any newline here as well: a positive line without a tab
            # would otherwise carry its line ending into the middle of the
            # output record (assumes read_file may keep line endings).
            first = sent.split("\t")[0].replace("\n", "")

            fields = neg_sents[randint(0, neg_length - 1)].split("\t")
            # Pick one of the first two fields, but never index past the end
            # when the sampled line has only a single field.
            choice = randint(0, min(1, len(fields) - 1))
            sampled = fields[choice].replace("\n", "")

            fout.write(first + "\t" + sampled + "\n")


def main():
    """Build one negative-pair file per input shard.

    Shard ``i`` is paired with shard ``(i + 1) % shard_count`` as its source
    of negatives, so each shard samples from its cyclic neighbour. The number
    of shards is taken from ``os.cpu_count()``.
    NOTE(review): presumably the ``inst.{i}.pos.txt.clean`` files were
    produced with one shard per CPU on this machine - verify.
    """
    src_dir = "../dataset/all/pos_clean"
    out_dir = "../dataset/all/neg_txt"
    os.makedirs(out_dir, exist_ok=True)

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to a single shard instead of failing in range()/modulo.
    shard_count = os.cpu_count() or 1

    for i in tqdm(range(shard_count), total=shard_count):
        j = (i + 1) % shard_count
        pos = os.path.join(src_dir, f"inst.{i}.pos.txt.clean")
        neg = os.path.join(src_dir, f"inst.{j}.pos.txt.clean")
        tgt = os.path.join(out_dir, f"inst.{i}.neg.txt")
        create(pos, neg, tgt)


if __name__ == "__main__":
    main()