# -*- coding: utf-8 -*-
import json
import os
import random
from tqdm import tqdm
# Seed the global RNG so that random.shuffle in DataToSample.creat_sample
# draws the same random sequence on every run.
random.seed(2021)
# Dictionary file: read line by line, each line parsed as a JSON object
# providing the 'product' and 'hangye1' keys.
dict_path_ = './data_dict/company_list.txt'
# Directory of raw input files; each file is loaded as JSON and its 'text'
# field is used.
path_in_ = './dataset_new'
# Output files: the first 9/10 of the shuffled samples go to the train file,
# the remainder to the test file.
path_out_train_ = './train_dataset/36Kr_ner/train.tsv'
path_out_test_ = './train_dataset/36Kr_ner/test.tsv'
class DataToSample(object):
    """
    Create training sample data from raw data.

    A dictionary file maps entity names to a category; every occurrence of a
    dictionary entry in a raw text is tagged character by character with
    'B-ORG-<category>' / 'I-ORG-<category>' labels, everything else with 'O'.
    """

    def __init__(self, dict_path):
        """
        Load the entity dictionary.

        :param dict_path: path to a UTF-8 file with one JSON object per line;
            each object must provide the keys 'product' (the string to search
            for) and 'hangye1' (the suffix appended to the ORG label --
            presumably an industry category; confirm against the data).
        :raises KeyError: if a line lacks 'product' or 'hangye1'.
        :raises json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        self.dict_ind = {}
        self.dict_path = dict_path
        with open(self.dict_path, 'r', encoding='utf-8') as fp:
            for line in fp:
                line = line.strip()
                # Blank lines (e.g. a trailing newline at the end of the file)
                # are not JSON documents; skip them instead of crashing.
                if not line:
                    continue
                line_new = json.loads(line)
                self.dict_ind[line_new['product']] = line_new['hangye1']

    def index_str(self, str1, str2):
        """
        Find all non-overlapping occurrences of str2 inside str1.

        The scan goes left to right and resumes right after each match.

        :param str1: the string to search in.
        :param str2: the substring to look for.
        :return: list of (start, end) tuples, end exclusive; empty when str2
            does not occur or is the empty string.
        """
        # An empty needle matches at every position without advancing the
        # scan, which would loop forever; treat it as "no occurrence".
        if not str2:
            return []
        width = len(str2)
        spans = []
        pos = str1.find(str2)
        while pos != -1:
            spans.append((pos, pos + width))
            pos = str1.find(str2, pos + width)
        return spans

    def _label_text(self, text):
        """
        Build the per-character label list for one text.

        Dictionary entries are applied in the dictionary's iteration order; an
        occurrence is skipped entirely when any character it covers has
        already been labeled by an earlier entry.

        :param text: the raw text.
        :return: list with one label per character of text.
        """
        labels = ['O'] * len(text)
        for word, category in self.dict_ind.items():
            tag = 'ORG-' + category
            for start, end in self.index_str(text, word):
                # Prevent cross labeling. Each label must be inspected
                # individually: a membership test against the list slice only
                # matches elements that are exactly equal to the probe.
                if any(label != 'O' for label in labels[start:end]):
                    continue
                labels[start] = 'B-' + tag
                for pos in range(start + 1, end):
                    labels[pos] = 'I-' + tag
        return labels

    def creat_sample(self, path_in, path_out_train, path_out_test):
        """
        Create training sample data.

        Every file in path_in is loaded as JSON and its 'text' field is
        labeled. One output line is produced per file: the characters joined
        by '\\002', a tab, the labels joined by '\\002', and a newline. The
        lines are shuffled with the global `random` module; the first 9/10 are
        written to path_out_train and the rest to path_out_test.

        :param path_in: directory containing the raw JSON files.
        :param path_out_train: output path for the training samples.
        :param path_out_test: output path for the test samples.
        """
        sample_data = []
        # os.listdir returns entries in arbitrary order; sort them so that a
        # seeded shuffle yields the same train/test split on every machine.
        files = sorted(os.listdir(path_in))
        for file in tqdm(files):
            file_path = os.path.join(path_in, file)
            with open(file_path, 'r', encoding='utf-8') as fp:
                text = json.load(fp)['text']
            text_label = self._label_text(text)
            text_new = '\002'.join(text)
            text_label_new = '\002'.join(text_label)
            sample_data.append(text_new + '\t' + text_label_new + '\n')
        random.shuffle(sample_data)
        train_num = len(sample_data) * 9 // 10
        with open(path_out_train, 'w', encoding='utf-8') as fp:
            fp.writelines(sample_data[:train_num])
        with open(path_out_test, 'w', encoding='utf-8') as fp:
            fp.writelines(sample_data[train_num:])
        print('completed')
if __name__ == '__main__':
    # Script entry point: load the dictionary, then build and write the
    # train/test sample files from the configured module-level paths.
    sampler = DataToSample(dict_path_)
    sampler.creat_sample(
        path_in=path_in_,
        path_out_train=path_out_train_,
        path_out_test=path_out_test_,
    )
|