RT
import numpy as np
import sys
import time
import threading
from queue import Queue
# Pin NumPy's global random state for the lifetime of the process.
np.random.seed(1)

# Capacities of the two bounded hand-off queues below.  READ_BATCH_SIZE is
# also the number of ``None`` end-of-input markers ProducerThread enqueues.
READ_BATCH_SIZE = 25
WRITE_BATCH_SIZE = 25

# r_q carries (pin_list, pin_vec_list) batches from the reader to the compute
# threads; w_q carries formatted output text from the compute threads to the
# writer.  ``None`` on either queue means "no more data".
r_q = Queue(maxsize=READ_BATCH_SIZE)
w_q = Queue(maxsize=WRITE_BATCH_SIZE)
class ProducerThread(threading.Thread):
    """Read pin vectors from a comma-separated file and enqueue them in batches.

    Each accepted input line has exactly 101 comma-separated fields: the pin
    identifier followed by 100 values that are parsed as ``float32``.  Lines
    with any other field count are skipped.  Batches are pushed onto the
    module-level queue ``r_q`` as ``(pin_list, pin_vec_list)`` tuples, and
    ``READ_BATCH_SIZE`` ``None`` markers are pushed afterwards to signal
    end-of-input to the consumers.

    Args:
        pin_cid3_vec_path: Path of the input file to read.
        batch_size: Number of accepted rows per enqueued batch.
        name: Thread name.
    """

    def __init__(self, pin_cid3_vec_path, batch_size=100, name='read_producer'):
        super(ProducerThread, self).__init__()
        self.pin_cid3_vec_path = pin_cid3_vec_path
        self.batch_size = batch_size
        self.name = name

    def run(self):
        """Stream the input file into ``r_q`` and then emit end markers.

        The end markers are sent from a ``finally`` clause: if the file cannot
        be opened or a row fails to parse, the exception still propagates (and
        is reported by the threading machinery), but the consumers blocked on
        ``r_q.get()`` are released instead of waiting forever.
        """
        pin_list = []
        pin_vec_list = []
        num = 0  # count of accepted (well-formed) rows so far
        try:
            with open(self.pin_cid3_vec_path, 'r') as f:
                for line in f:
                    buf = line.strip().split(',')
                    if len(buf) != 101:
                        continue
                    num += 1
                    pin_list.append(buf[0])
                    pin_vec_list.append(np.array(buf[1:], dtype=np.float32))
                    if num % self.batch_size == 0:
                        r_q.put((pin_list, pin_vec_list))
                        print('Read %d' % num)
                        # Start fresh lists; the enqueued ones now belong to
                        # whichever consumer picks them up.
                        pin_list = []
                        pin_vec_list = []
                        time.sleep(0.1)
            # Flush the final, partially filled batch.
            if pin_list:
                r_q.put((pin_list, pin_vec_list))
                print('Read %d' % num)
        finally:
            # One marker is consumed per compute thread; READ_BATCH_SIZE
            # markers are sent regardless of how many threads are running.
            for _ in range(READ_BATCH_SIZE):
                r_q.put(None)
class ComputeConsumerThread(threading.Thread):
    """Score batches of pin vectors against the cid3 matrix and rank columns.

    Batches are taken from the module-level queue ``r_q``; for every pin the
    indices of the ``top_K`` highest-scoring cid3 columns are formatted as
    ``"<pin>\\t<idx>,<idx>,..."`` and the whole batch is pushed onto ``w_q`` as
    one newline-terminated string.  A ``None`` item on ``r_q`` ends the thread.

    Args:
        cid3_vec_path: Path passed to ``get_cid3_vec`` to load the matrix.
        top_K: Number of top-ranked column indices to keep per pin.
    """

    def __init__(self, cid3_vec_path, top_K):
        super(ComputeConsumerThread, self).__init__()
        # Loaded once per thread instance at construction time.
        self.cid3_vec = get_cid3_vec(cid3_vec_path)
        self.top_K = int(top_K)

    def run(self):
        """Consume batches until a ``None`` marker arrives.

        Exactly one ``None`` is always forwarded to ``w_q`` when this thread
        stops, even if it stops because of an exception, so the writer's
        count of finished compute threads stays correct and it cannot hang.
        """
        try:
            while True:
                queue_item = r_q.get()
                if queue_item is None:
                    break
                pin_list, pin_vec_list = queue_item
                # Rows are pins, columns are cid3 entries.
                # NOTE(review): assumes the cid3 vectors have the same length
                # as the pin vectors so the product is defined - confirm.
                cid3_ref = np.dot(np.array(pin_vec_list), self.cid3_vec)
                # Negate so argsort yields indices in descending score order.
                sorted_idx = np.argsort(-cid3_ref, axis=1)
                out_lines = [
                    pin + '\t' + ','.join(ranked[:self.top_K].astype(str))
                    for pin, ranked in zip(pin_list, sorted_idx)
                ]
                w_q.put('\n'.join(out_lines) + '\n')
                time.sleep(0.1)
        finally:
            w_q.put(None)
class WriteConsumerThread(threading.Thread):
    """Drain the module-level queue ``w_q`` into an output file.

    Each non-``None`` item is written verbatim.  Each ``None`` item counts as
    one finished compute thread; the writer stops once it has seen
    ``n_compute_threads`` of them.

    Args:
        output_path: File to (over)write with the results.
        n_compute_threads: Number of ``None`` markers to wait for.
        thread_id: Identifier stored on the instance; not used by ``run``.
    """

    def __init__(self, output_path, n_compute_threads, thread_id):
        super(WriteConsumerThread, self).__init__()
        self.output_path = output_path
        self.n_compute_threads = n_compute_threads
        self.thread_id = thread_id

    def run(self):
        """Write queued text until every compute thread has signalled done."""
        finished = 0     # number of None markers received so far
        flush_count = 0  # number of chunks written so far
        # The context manager flushes and closes the file even when a write
        # fails part-way through.
        with open(self.output_path, 'w') as fw:
            while True:
                out_lines = w_q.get()
                if out_lines is None:
                    finished += 1
                    if finished >= self.n_compute_threads:
                        break
                    continue
                fw.write(out_lines)
                flush_count += 1
                # Push data to disk every tenth chunk so progress is visible.
                if flush_count % 10 == 0:
                    fw.flush()
                print('Write')
                time.sleep(0.1)
def get_cid3_vec(cid3_vec_path):
    """Load a whitespace-separated vector file as a transposed float32 matrix.

    Every non-blank line is expected to look like ``<label> v1 v2 ... vN``.
    The leading label is discarded and the remaining fields become one row.
    Blank lines (for example a trailing empty line) are skipped, and runs of
    whitespace between fields are treated as a single separator.

    Args:
        cid3_vec_path: Path of the vector file.

    Returns:
        ``numpy.ndarray`` of dtype ``float32`` holding the transpose of the
        row matrix, i.e. one column per input line.

    Raises:
        ValueError: If a field is not a valid float or rows differ in length.
    """
    rows = []
    with open(cid3_vec_path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # An empty row would make the array ragged and fail to build.
                continue
            rows.append(fields[1:])
    return np.transpose(np.array(rows, dtype=np.float32))
def thread_main(pin_cid3_vec_path,
                cid3_vec_path,
                output_path,
                top_K,
                batch_size,
                n_compute_threads=5):
    """Run the read -> compute -> write pipeline and report the elapsed time.

    One ProducerThread, ``n_compute_threads`` ComputeConsumerThread instances
    and one WriteConsumerThread are started and then joined.

    Args:
        pin_cid3_vec_path: Input file handed to the producer.
        cid3_vec_path: Vector file handed to every compute thread.
        output_path: Result file handed to the writer.
        top_K: Number of top-ranked indices to keep per pin.
        batch_size: Number of rows per batch read by the producer.
        n_compute_threads: Number of compute threads to start.

    Raises:
        ValueError: If ``n_compute_threads`` is outside
            ``1..READ_BATCH_SIZE``.  The producer enqueues exactly
            ``READ_BATCH_SIZE`` end markers and each compute thread consumes
            one, so more threads than markers (or no threads at all) would
            leave the pipeline blocked forever.
    """
    if not 1 <= n_compute_threads <= READ_BATCH_SIZE:
        raise ValueError(
            "n_compute_threads must be between 1 and %d, got %r"
            % (READ_BATCH_SIZE, n_compute_threads))

    # Start the clock before any thread runs so the whole pipeline is timed.
    start_time = time.time()
    print("Compute Start: %s" % time.ctime(int(start_time)))

    threads = []
    p = ProducerThread(pin_cid3_vec_path, batch_size)
    p.start()
    threads.append(p)

    for _ in range(n_compute_threads):
        compute_c = ComputeConsumerThread(cid3_vec_path, top_K)
        compute_c.start()
        threads.append(compute_c)

    # A single writer owns the output file.
    write_c = WriteConsumerThread(output_path, n_compute_threads, thread_id=0)
    write_c.start()
    threads.append(write_c)

    for thread in threads:
        thread.join()

    end_time = time.time()
    print("Compute End: %s" % time.ctime(int(end_time)))
    # time.time() is already in seconds; no unit conversion is needed.
    last_time = end_time - start_time
    print("Last time: %.2f seconds" % last_time)
if __name__ == '__main__':
    # Usage: script PIN_VEC_FILE CID3_VEC_FILE OUTPUT_FILE K BATCH_SIZE N_THREADS
    # Unless all six arguments are supplied, the built-in defaults are used.
    if len(sys.argv) > 6:
        pin_cid3_vec_path, cid3_vec_path, output_path = sys.argv[1:4]
        K, batch_size, n_compute_threads = [int(arg) for arg in sys.argv[4:7]]
    else:
        home_path = '../../data/'
        pin_cid3_vec_path = home_path + 'pin_vec_by_cid3_vec'
        cid3_vec_path = home_path + 'forU_cid3_vec'
        output_path = home_path + 'forU_pin_cid3_ref'
        K, batch_size, n_compute_threads = 10, 10000, 20

    thread_main(pin_cid3_vec_path, cid3_vec_path, output_path,
                K, batch_size, n_compute_threads)
'''
end_time = time.time()
print("Compute End: %s" % time.ctime(int(end_time)))
last_time = end_time - start_time
print("Last time: %.2f seconds" % (last_time / 1000))
'''
|