目的
1. 利用朴素贝叶斯分类器判断短信(数据集已给出)是否为垃圾短信;2. 可参考给定的文本分类代码,但不可以直接调用 sklearn 的方法。
数据集
使用经典的 SMSSpamCollection.txt 数据集(百度云下载:点击这里,提取码:2333)。文件每行由标签和短信内容组成,二者以制表符分隔;其中 spam 表示垃圾短信,ham 表示非垃圾短信。注意:代码中读取的文件名为 SMSS.txt,需将数据集重命名或修改代码中的路径。
源代码
import math
import random
import re
from collections import Counter

from sklearn import model_selection
def load_data(path='SMSS.txt'):
    """Load the SMS Spam Collection file and tokenize every message.

    Each non-blank line must hold a label, a tab, then the message text.
    A label of 'spam' maps to 1; any other label maps to 0.

    Args:
        path: file to read; defaults to 'SMSS.txt' in the working directory.

    Returns:
        (x, y): x is a list of token lists produced by clean_data, and y is
        the list of 0/1 labels aligned with x.

    Raises:
        ValueError: if a non-blank line contains no tab separator.
    """
    x = []
    y = []
    with open(path, 'r', encoding='utf-8') as fr:
        for line_no, line in enumerate(fr, start=1):
            if not line.strip():
                # Skip blank lines (e.g. a trailing newline at EOF) instead of
                # failing on a missing message field.
                continue
            # Split only on the FIRST tab so a message that itself contains a
            # tab is kept whole (maxsplit=2 used to truncate it).
            label, sep, message = line.partition('\t')
            if not sep:
                raise ValueError(
                    '{}:{}: expected a tab between label and message'.format(
                        path, line_no))
            x.append(clean_data(message))
            y.append(1 if label == 'spam' else 0)
    return x, y
def clean_data(origin_info, min_len=3):
    r"""Lower-case a message and split it into word tokens.

    A token is a maximal run of regex word characters (\w: letters, digits,
    underscore). Every other character acts as a separator. Tokens shorter
    than ``min_len`` characters are dropped.

    Args:
        origin_info: raw message text.
        min_len: minimum token length to keep (default 3).

    Returns:
        The list of kept tokens, in their original order.
    """
    # Raw-string pattern: the old '\W' plain literal was an invalid escape
    # sequence (SyntaxWarning since Python 3.12). Finding runs of \w is
    # equivalent to replacing \W with spaces and splitting on whitespace.
    return [word for word in re.findall(r'\w+', origin_info.lower())
            if len(word) >= min_len]
def build_word_set(x_train, y_train, x_test):
    """Train a multinomial naive Bayes model and classify one message.

    Despite its name, this function both fits the model on the training data
    and scores a single tokenized message against it. Each class c gets the
    score

        log P(c) + sum over w in x_test of
            log((count(w, c) + 1) / (total_words(c) + |V|))

    (add-one / Laplace smoothing), where |V| is the number of distinct words
    seen in training.

    Args:
        x_train: tokenized training messages (each a list of words).
        y_train: labels aligned with x_train; 1 = spam, 0 = ham. Samples with
            any other label are ignored.
        x_test: one tokenized message (a list of words) to classify.

    Returns:
        '垃圾短信' (spam) when the spam score is >= the ham score,
        otherwise '非垃圾短信' (ham).

    Raises:
        ValueError: if the training data contains no ham or spam sample.
    """
    ham_counts = Counter()
    spam_counts = Counter()
    ham_docs = 0
    spam_docs = 0
    for words, label in zip(x_train, y_train):
        if label == 0:
            ham_docs += 1
            ham_counts.update(words)
        elif label == 1:
            spam_docs += 1
            spam_counts.update(words)

    total_docs = ham_docs + spam_docs
    if total_docs == 0:
        raise ValueError('training data contains no ham (0) or spam (1) sample')

    vocab_size = len(ham_counts.keys() | spam_counts.keys())
    # The smoothing denominator is the total number of word occurrences in
    # the class plus |V|, NOT the number of messages in the class.
    ham_denominator = sum(ham_counts.values()) + vocab_size
    spam_denominator = sum(spam_counts.values()) + vocab_size

    # A class absent from training has prior 0, i.e. log-prior -inf, so it
    # can never win; math.log(0) would raise instead.
    ham_pro = math.log(ham_docs / total_docs) if ham_docs else -math.inf
    spam_pro = math.log(spam_docs / total_docs) if spam_docs else -math.inf
    for word in x_test:
        ham_pro += math.log((ham_counts[word] + 1) / ham_denominator)
        spam_pro += math.log((spam_counts[word] + 1) / spam_denominator)

    # These are log-scores (log prior + log likelihood), not raw probabilities.
    print('垃圾短信概率:', spam_pro)
    print('非垃圾短信概率:', ham_pro)
    if spam_pro >= ham_pro:
        return '垃圾短信'
    else:
        return '非垃圾短信'
if __name__ == '__main__':
    # Load the labelled corpus: x holds token lists, y holds 1 (spam) / 0 (ham).
    x, y = load_data()

    # Shuffle-and-split 80/20 with the standard library instead of
    # sklearn.model_selection.train_test_split, because the assignment
    # forbids calling sklearn. ceil() mirrors train_test_split's rounding of
    # a fractional test_size.
    indices = list(range(len(x)))
    random.shuffle(indices)
    test_size = math.ceil(len(indices) * 0.2)
    test_idx, train_idx = indices[:test_size], indices[test_size:]
    x_train = [x[i] for i in train_idx]
    y_train = [y[i] for i in train_idx]
    x_test = [x[i] for i in test_idx]
    y_test = [y[i] for i in test_idx]

    # Classify one randomly chosen test message, then show its true label so
    # the prediction can be checked.
    text = x_test[0]
    print(text)
    result = build_word_set(x_train, y_train, text)
    print(result)
    print('实际标签:', '垃圾短信' if y_test[0] == 1 else '非垃圾短信')
结果
从测试集中随机选取一条短信,输出其分词结果,再判断该短信是否为垃圾短信。