2021-06-06 20:50:36 +08:00
|
|
|
import os
|
|
|
|
from random import randint
|
|
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
from utils import ORIGINAL_DATA_BASE, read_file
|
|
|
|
|
2024-04-11 16:43:57 +08:00
|
|
|
from my_utils import multi_thread
|
|
|
|
|
2021-06-06 20:50:36 +08:00
|
|
|
|
|
|
|
def create(pos, neg, tgt):
    """Write one (positive, random-negative) pair per line of *pos* to *tgt*.

    For every line of the positive file, its first tab-separated field is
    kept and paired with one of the first two tab-separated fields of a
    randomly chosen line of the negative file. Each output line has the form
    ``"<pos field>\\t<neg field>\\n"``.

    Args:
        pos: Path of the file whose lines supply the first column.
        neg: Path of the file from which the second column is sampled.
        tgt: Path of the output file; it is created or truncated.

    Raises:
        ValueError: if *pos* has lines but *neg* has none to sample from.
    """
    pos_sents = read_file(pos)
    neg_sents = read_file(neg)
    neg_length = len(neg_sents)
    # randint(0, -1) would fail with an opaque "empty range" error; report the
    # real problem instead. An empty *pos* still produces an empty output file.
    if pos_sents and not neg_length:
        raise ValueError(
            f"cannot sample negatives: {neg!r} has no lines (needed for {pos!r})"
        )

    with open(tgt, "w", encoding="utf-8") as fout:
        for sent in tqdm(pos_sents):
            # Strip a trailing line break in case the line has no tab, so it
            # cannot end up in the middle of the output record.
            first = sent.split("\t")[0].rstrip("\r\n")
            fields = neg_sents[randint(0, neg_length - 1)].split("\t")
            # Pick one of the first two fields (same draw as before for normal
            # two-field lines); a single-field line no longer raises IndexError.
            chosen = fields[randint(0, min(len(fields), 2) - 1)]
            fout.write(first + "\t" + chosen.rstrip("\r\n") + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
def main(src_dir="../dataset/all/pos_clean", out_dir="../dataset/all/neg_txt",
         num_shards=None):
    """Generate a negative-pair file for every positive shard.

    Shard ``i`` (``inst.{i}.pos.txt.clean`` in *src_dir*) is paired with
    negatives sampled from shard ``(i + 1) % num_shards``. When there is more
    than one shard, no shard draws negatives from itself. The result for shard
    ``i`` is written to ``inst.{i}.neg.txt`` in *out_dir*.

    Args:
        src_dir: Directory holding the cleaned positive shards. Relative paths
            resolve against the current working directory.
        out_dir: Directory for the generated files; created if missing.
        num_shards: Number of shards to process. Defaults to
            ``os.cpu_count()``. NOTE(review): that default presumably mirrors
            how many shards the input was split into; confirm, since it
            breaks on a machine with a different CPU count.

    Raises:
        RuntimeError: if *num_shards* is omitted and the CPU count cannot be
            determined (``os.cpu_count()`` returned ``None``).
    """
    if num_shards is None:
        num_shards = os.cpu_count()
        if num_shards is None:
            raise RuntimeError(
                "os.cpu_count() is undetermined; pass num_shards explicitly"
            )

    os.makedirs(out_dir, exist_ok=True)
    # tqdm infers the total from len(range(...)).
    for i in tqdm(range(num_shards)):
        j = (i + 1) % num_shards
        pos = os.path.join(src_dir, f"inst.{i}.pos.txt.clean")
        neg = os.path.join(src_dir, f"inst.{j}.pos.txt.clean")
        tgt = os.path.join(out_dir, f"inst.{i}.neg.txt")
        create(pos, neg, tgt)
|
|
|
|
# Script entry point: generate the negative-pair files only when this module
# is executed directly, never on import.
if __name__ == "__main__":
    main()
|