【Python】ナイーブベイズ分類器
学習プロセス
P(y_i|x)=\frac{P(x|y_i)P(y_i)}{P(x)}
ナイーブベイズは上の数式から確率を算出します。
今回は行うのはスパムメールの分類です。
出現する単語からスパムか否かを判断します。
例えば「無料」という単語がスパムメールの80%に含まれているとし、
スパムでないメールの20%にこの単語が含まれているとして、
スパムである確立を以下のように求めます。
\frac{0.8}{(0.8+0.2)}
だいたい80%です。
この計算をスパムメールに含まれている単語に対して行い確率を求めます。
データの用意
分類対象のデータはスパムのテストデータです。
以下のサイトからデータを取得できるようです。
https://spamassassin.apache.org/old/publiccorpus/
圧縮された状態で置かれていますがこの中に数百〜数千のメールを記載したファイルが格納されています。
このファイルから分類のために使う情報を抽出します。
ファイルの内容はだいたい以下の形式になっています。
From checkoutthesefreakhugecocksinapussy@framesetup.com Fri Aug 23 18:57:51 2002
Return-Path: <checkoutthesefreakhugecocksinapussy@framesetup.com>
Delivered-To: zzzz@localhost.spamassassin.taint.org
Received: from localhost (localhost [127.0.0.1])
by phobos.labs.spamassassin.taint.org (Postfix) with ESMTP id 0C12543F99
for <zzzz@localhost>; Fri, 23 Aug 2002 13:57:44 -0400 (EDT)
Received: from mail.webnote.net [193.120.211.219]
by localhost with POP3 (fetchmail-5.9.0)
for zzzz@localhost (single-drop); Fri, 23 Aug 2002 18:57:44 +0100 (IST)
Received: from framesetup.com (1Cust249.tnt15.det3.da.uu.net [67.217.14.249])
by webnote.net (8.9.3/8.9.3) with SMTP id SAA11649
for <zzzz@spamassassin.taint.org>; Fri, 23 Aug 2002 18:48:51 +0100
Message-Id: <200208231748.SAA11649@webnote.net>
From: "Free pussy" <checkoutthesefreakhugecocksinapussy@framesetup.com>
To: <zzzz@spamassassin.taint.org>
Subject: Free big cock in Pussy
Sender: "Free pussy" <checkoutthesefreakhugecocksinapussy@framesetup.com>
Mime-Version: 1.0
Date: Sat, 24 Aug 2002 13:46:11 -0400
Content-Type: text/html; charset="ISO-8859-1"
Content-Transfer-Encoding: 8bit
<!DOCTYPE HTML PUB...
今回は件名を分類に使うので「Subject:」の部分を抽出します。
実装
# coding: utf-8
import math
import glob
import re
import random
# 親クラス
class NaiveBayes:
def __init__(self):
self.probability = []
# 継承先で何かしら実装する
def preprocess(self, data):
pass
# 分類対象や前処理が違っていてもこの部分は共通だと想定
def predict(self,data):
true_prob = 0.0
false_prob = 0.0
if len(self.probability) == 0:
return 0.0
for p in self.probability:
if p['value'] in data:
true_prob += math.log(p['true_prob'])
false_prob += math.log(p['false_prob'])
else:
true_prob += math.log(1-p['true_prob'])
false_prob += math.log(1-p['false_prob'])
true_prob = math.exp(true_prob)
false_prob = math.exp(false_prob)
return true_prob / (true_prob+false_prob)
# スパムを分類するクラス
class SpamAssassinClassfier(NaiveBayes, object):
def __init__(self):
super(SpamAssassinClassfier,self).__init__()
# スパム、非スパムがそれぞれの単語を含む確率
def preprocess(self, data):
self.train_data = data
# 単語ごとにその単語が含まれるスパムと非スパムの数を数える
spam_count_data = {}
for d in data:
for word in d['words']:
if word not in spam_count_data.keys():
spam_count_data[word] = [0,0]
spam_count_data[word][d['flag']] += 1
# スムージング 0の確率を防ぐため
s = 0.5
# スパムの数
spam_num = len([1 for d in data if d['flag'] == 1])
print(spam_num)
non_spam_num = len(data) - spam_num
# スパムだった時、非スパムだった時単語が含まれる確率
for key,value in spam_count_data.items():
prob_spam = (s+value[1]) / (2*s+spam_num)
prob_no_spam = (s+value[0]) / (2*s+non_spam_num)
self.probability.append({'value': key, 'false_prob': prob_no_spam, 'true_prob': prob_spam})
def get_data(self, directory):
data = []
for d in directory:
for f in glob.glob(d[0]):
fp = open(f, "r")
content = fp.read()
fp.close()
result = re.search(r"Subject: (.*?)\n",content, re.DOTALL)
if result is not None and result.group(1) is not None:
message = result.group(1).lower()
data.append({'words': set(re.findall("[a-z']+",message)), 'flag': d[1]})
return data
前処理
get_dataメソッドで必要なデータを取得します。
- 引数… [スパムアサシンから取得したディレクトリ名,スパムかどうか(0or1)]
- 戻り値… [単語のリスト,スパムかどうか(0or1)]
具体的には以下の形式のデータです。
{'flag': 0, 'words': set(['swim', 'drunk', 'dr', "don't", 'paging', 'let', 'darwin', 'friends'])}
そして抽出したデータを使い、前処理に渡しスパムと非スパムがそれぞれの単語を含む確率を計算します。
preprocessメソッドに先ほど抽出したデータを渡すと、以下の形式の情報を含んだリストを生成します。
{'true_prob': 0.0029940119760479044, 'value': 'paperless', 'false_prob': 0.0001785076758300607}
予測
予測はpredictメソッドが行います。
このメソッドにメールから抽出した件名をリストにしてを渡します。
最終的に
[スパムが単語を含む確率] / ( [スパムが単語を含む確率]+[非スパムが単語を含む確率] )=スパムである確率
を計算します。
動作確認
実装したクラスを以下のように使い動かしてみます。
sac = SpamAssassinClassfier()
sac.preprocess(sac.get_data([["./train_easy_ham/*", 0], ["./train_hard_ham/*", 0], ["./train_spam/*", 1]]))
test_data = sac.get_data([["./test_easy_ham/*", 0], ["./test_hard_ham/*", 0], ["./test_spam/*", 1]])
for _ in range(1000):
index = random.randint(0,len(test_data)-1)
res = sac.predict(test_data[index]['words'])*100
print("{} result: {}".format("spam" if test_data[index]['flag'] == 1 else "no spam", res))
実行結果は以下のようになりました。
ある程度は予測できているのかと思います。
no spam result: 36.9468931639
no spam result: 9.9521377882e-08
spam result: 99.4858440503
spam result: 99.3350649141
no spam result: 3.112049255
no spam result: 7.53670490026
spam result: 99.965397391
no spam result: 0.00208135050203
no spam result: 1.45894522895
no spam result: 0.00404429192909
no spam result: 54.9840847768
no spam result: 11.6211268861
spam result: 99.9999185414
spam result: 99.758010449
no spam result: 1.01174732249
no spam result: 0.0323736493299
no spam result: 0.0394170019782
no spam result: 0.00158748007806
no spam result: 8.85842641012
no spam result: 3.28002811198
spam result: 93.6404115488
no spam result: 0.00404429192909
no spam result: 0.000756334519003
spam result: 12.447127353
spam result: 93.6903540025
no spam result: 3.01023751225
no spam result: 0.00122653779033
spam result: 99.9935261269
no spam result: 8.25183442606
no spam result: 6.13278704403
...
今回はスパムの判定を行うプログラムを実装しましたが、文章の分類に関して他にもいろいろできそうです。