Preface
In day-to-day work there are a few snippets of code that get reused very often, and hunting each one down again every time is a waste of time. Here I collect the ones I reach for most, ready to copy quickly, so you can get straight on with your own debugging!
tensorflow
import tensorflow as tf
import numpy as np
# Commented-out reference: masked average of padded tag embeddings
# (embedding_lookup, then matmul with the 0/1 length mask, then divide by the mask sum).
# input_ids = tf.placeholder(dtype=tf.int32, shape=[None,4])
# input_len = tf.placeholder(dtype=tf.int32, shape=[None,4])
# embedding = tf.Variable(np.identity(5, dtype=np.int32))
# tag_embedding = tf.nn.embedding_lookup(embedding, input_ids)
# tag_embedding = tf.matmul(tag_embedding, tf.expand_dims(input_len, -1), transpose_a=True)
# tag_embedding = tf.squeeze(tag_embedding)
#
# tag_embedding_s = tf.divide(tag_embedding, tf.expand_dims(tf.reduce_sum(input_len, 1) ,-1))
#
# sess = tf.InteractiveSession()
# sess.run(tf.global_variables_initializer())
# print(embedding.eval())
# print("***")
#
#
# my_inputs_id = [[1],[2,3,1,0],[3,2]]
# tag_input = []
# tag_len = []
# max_tag_len = 4
# for cur_tag in my_inputs_id:
#     cur_tag_len = len(cur_tag)
#     pad_tag_len = max_tag_len - cur_tag_len
#     tag_input.append(cur_tag + [0] * pad_tag_len)
#     tag_len.append([1]*cur_tag_len + [0]*pad_tag_len)
# print(my_inputs_id)
# print(tag_len)
#
#
# print(sess.run(tag_embedding_s, feed_dict={input_ids: tag_input, input_len: tag_len}))
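# A quick NumPy sketch of the same masked mean, handy for sanity checking
# (assumes the identity embedding and the padded ids/mask from the example above):
# embedding_np = np.identity(5, dtype=np.float32)
# tag_input_np = np.array([[1, 0, 0, 0], [2, 3, 1, 0], [3, 2, 0, 0]])    # padded tag ids
# tag_mask_np = np.array([[1, 0, 0, 0], [1, 1, 1, 1], [1, 1, 0, 0]])     # 1 = real tag, 0 = padding
# looked_up = embedding_np[tag_input_np]                                 # [3, 4, 5]
# masked_sum = (looked_up * tag_mask_np[:, :, None]).sum(axis=1)         # [3, 5]
# masked_mean = masked_sum / tag_mask_np.sum(axis=1, keepdims=True)      # [3, 5]
# print(masked_mean)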
input_ids = tf.placeholder(dtype=tf.int32, shape=[None])            # ids used to look up per-id attention weights
alpha_embedding = tf.Variable(tf.random_uniform((100, 3), -1, 1))   # one attention-weight row of size 3 per id
# three candidate embedding tables of shape [2, 5]
embed_list = [tf.Variable([[1.0, 2.0, 3.0, 6.0, 7.0], [4.0, 5.0, 6.0, 7.0, 7.0]]),
              tf.Variable([[1.1, 2.1, 3.1, 7.0, 7.0], [4.0, 5.0, 6.0, 7.0, 7.0]]),
              tf.Variable([[1.2, 2.2, 3.2, 7.0, 7.0], [4.0, 5.0, 6.0, 7.0, 7.0]])]
stack_embed = tf.stack(embed_list, axis=-1)                          # [2, 5, 3]
# attention merge: softmax the alpha weights over the 3 tables, then take the weighted sum
alpha_embed = tf.nn.embedding_lookup(alpha_embedding, input_ids)   # [batch, 3]
alpha_embed_expand = tf.expand_dims(alpha_embed, 1)                # [batch, 1, 3]
alpha_i_sum = tf.reduce_sum(tf.exp(alpha_embed_expand), axis=-1)   # [batch, 1]
ff = stack_embed * tf.exp(alpha_embed_expand)                      # [2, 5, 3]; broadcasting requires batch == 2 here
merge_emb = tf.reduce_sum(ff, axis=-1) / alpha_i_sum               # [batch, 5]
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run(merge_emb, feed_dict={input_ids: np.array([1, 2])})
print(result.shape)  # (2, 5): one attention-merged 5-dim embedding per input id
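The merge above is just a softmax over the three candidate tables. Below is a minimal NumPy sketch of the same formula, assuming the same three [2, 5] tables; the alpha rows are random stand-ins for the embedding_lookup on ids [1, 2], so only the shapes and the formula carry over.

import numpy as np
tables = np.stack([
    np.array([[1.0, 2.0, 3.0, 6.0, 7.0], [4.0, 5.0, 6.0, 7.0, 7.0]]),
    np.array([[1.1, 2.1, 3.1, 7.0, 7.0], [4.0, 5.0, 6.0, 7.0, 7.0]]),
    np.array([[1.2, 2.2, 3.2, 7.0, 7.0], [4.0, 5.0, 6.0, 7.0, 7.0]]),
], axis=-1)                                                          # [2, 5, 3]
alpha = np.random.uniform(-1, 1, size=(2, 3))                        # stand-in for the looked-up alpha rows
weights = np.exp(alpha) / np.exp(alpha).sum(axis=-1, keepdims=True)  # softmax over the 3 tables
merged = (tables * weights[:, None, :]).sum(axis=-1)                 # [2, 5], same formula as merge_emb above
print(merged.shape)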
pyspark
from pyspark import SparkContext, SparkConf, PickleSerializer
from pyspark.sql import Row
from pyspark.sql.session import SparkSession
from py4j.protocol import Py4JJavaError
class SparkContext_dirs(SparkContext):
    """SparkContext subclass that reads several input dirs (recursively) as one text RDD."""
    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0,
                 serializer=PickleSerializer(), conf=None, gateway=None, jsc=None):
        SparkContext.__init__(self, master=master, appName=appName, sparkHome=sparkHome, pyFiles=pyFiles,
                              environment=environment, batchSize=batchSize, serializer=serializer, conf=conf,
                              gateway=gateway, jsc=jsc)

    def text_dirs(self, dirs):
        hadoopConf = {"mapreduce.input.fileinputformat.inputdir": ",".join(dirs),
                      "mapreduce.input.fileinputformat.input.dir.recursive": "true"}
        pair = self.hadoopRDD(inputFormatClass="org.apache.hadoop.mapred.TextInputFormat",
                              keyClass="org.apache.hadoop.io.LongWritable",
                              valueClass="org.apache.hadoop.io.Text",
                              conf=hadoopConf)
        # hadoopRDD yields (byte offset, line); keep only the line text
        text = pair.map(lambda kv: kv[1])
        return text
def get_filter(x):
    # drop "not read" log lines: keep records whose flag right after "|" is "1"
    index = x.find("|")
    if index != -1 and x[index+1] == "1":
        return True
    else:
        return False
def get_field(x, data_str, time_str):
    # extract (user, [item@timestamp]); the timestamp is data_str with dashes removed plus time_str
    index = x.find("|")
    user, item = x[:index].split("_")[:2]
    timestamp = "".join(data_str.split("-")) + time_str
    value = item + "@" + timestamp
    return (user, [value])
def get_keyvalue(x):
    # split one line into (key, [value]), stripping the wrapper characters around each field
    key, values = x.split()
    key = key[2:-1]
    values = values[3:-3]
    return (key, [values])
def get_novel_data():
    data_strs = ['2021-06-09', '2021-06-08']
    conf = SparkConf().setAppName("miniProject").setMaster("local[1]")
    # conf = SparkConf().setAppName("lg").setMaster("spark://192.168.10.182:7077")
    sc = SparkContext(conf=conf)
    # a plain Python program has no default SparkSession, only the SparkContext built above
    lines = sc.textFile("./part-00007")
    lines = lines.map(lambda x: get_keyvalue(x))
    lines = lines.reduceByKey(lambda a, b: a + b)
    lines = lines.filter(lambda x: len(x[1]) > 1)   # keep keys that aggregated more than one value
    # the block below needs sc = SparkContext_dirs(conf=conf) so that text_dirs is available
    '''
    dirs = []
    for data_str in data_strs:
        new_file = "../data/book/%s" % data_str
        dirs.append(new_file)
    lines = sc.text_dirs(dirs)
    cur_rdd_has_read = lines.filter(lambda x: get_filter(x))
    cur_rdd_get_columns = cur_rdd_has_read.map(lambda x: get_field(x, "08", "01"))
    '''
    # print the first few records
    count = 0
    for i in lines.collect():
        print(i)
        if count == 4:
            break
        count = count + 1

if __name__ == "__main__":
    get_novel_data()
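Since, as the comment in get_novel_data notes, a plain Python script has no default SparkSession, one has to be built explicitly when the DataFrame API is wanted. A minimal sketch, assuming a local run and the same sample file (app name and master are placeholders):

from pyspark.sql import SparkSession    # already imported above via pyspark.sql.session

spark = SparkSession.builder.appName("miniProject").master("local[1]").getOrCreate()
sc2 = spark.sparkContext                  # the underlying SparkContext
df = spark.read.text("./part-00007")      # DataFrame with one string column named "value"
df.show(5, truncate=False)
spark.stop()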