#!/usr/bin/python
#
#Copyright (C) 2007 Roland Memisevic
#
#This program is distributed WITHOUT ANY WARRANTY; without even the implied 
#warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
"""
Feature extraction for spam filtering.

This script (very beta) performs simple feature extraction for 
classifying emails as either spam or no spam. The script looks for 
the mailbox-files "spam" and "nospam" in the current directory 
and extracts (mainly word-)features using an estimate of mutual 
information between word-occurrences and the class-label.
In its current form, the script uses the module "saveloaddata" 
by Karthikesh Raju to dump the features. 
This, and other properties should be easily adaptable to personal 
preferences.
"""

import mailbox, email
import re
import saveloaddata
from numarray import *

# Number of most frequent words kept per class (spam / no spam) as candidate
# features; bounding the candidate set bounds the extraction time.
numwords = 1000
# Number of candidate words, ranked by their estimated mutual information
# with the class label, that are finally kept as word features.
numwordfeatures = 200

# The mailbox files "spam" and "nospam" are read from the current directory;
# every message is parsed with email.message_from_file.
spambox = mailbox.UnixMailbox(open("./spam"),email.message_from_file)
nospambox = mailbox.UnixMailbox(open("./nospam"),email.message_from_file)

# One class label per processed mail (1 = spam, 0 = no spam).
isspam = []
# Not used by any active code in this script.
allwords = {}
pdelim = re.compile(r"\W+")                 #use non-word chars as separators
pnum = re.compile(r"\d")                    #numerals present?
pthreecaps=re.compile(r"[A-Z][A-Z][A-Z]+")  #three or more capitalized letters in a row?

# Content types whose payload is kept (matched case-insensitively); parts of
# any other type are treated as attachments and blanked by traverse().
# Written as a raw string so that "\w" reaches the regex engine unchanged
# instead of being read as an (invalid) string escape.
GOOD_CONTENT_RE = re.compile(r'(text/\w+|multipart/\w+)', re.I)
# File-name extensions that mark a message part as a file attachment.
FILEEXT_RE = re.compile(r'(\.exe|\.zip|\.gif|\.jpg|\.jpeg)$', re.I)
def traverse(mail, stats):
    """
    Recursively blank out attachments of a message and collect statistics.

    mail  -- email message object; it is modified in place.
    stats -- dictionary of statistics. Pass an empty dictionary for the
             top-level message; the same dictionary is handed on to all
             sub-parts and filled in place.

    Returns the tuple (mail, stats), where stats holds:
      'ctypes'        -- content types of the message and of every visited part
      'containsfiles' -- True if some part has a file name matching FILEEXT_RE
      'nosubject'     -- True if the top-level message has no (or an empty)
                         subject
      'threecaps'     -- True if the top-level subject matches pthreecaps
    """
    if not stats:
      #an empty dictionary marks the top-level call: initialise the
      #statistics and inspect the subject header once
      stats['ctypes'] = []
      stats['containsfiles'] = False
      stats['nosubject'] = False
      stats['threecaps'] = False
      subject = mail['subject']
      if not subject:
        stats['nosubject'] = True
      elif pthreecaps.search(subject):
        stats['threecaps'] = True
    ct = mail.get_content_type()
    stats['ctypes'].append(ct)
    fn = mail.get_filename()
    isfile = bool(fn and FILEEXT_RE.search(fn))
    if isfile:
      stats['containsfiles'] = True
    if mail.is_multipart():
      #clean every sub-part first, then put the cleaned parts back
      mail.set_payload([ traverse(x,stats)[0] for x in mail.get_payload() ])
    if isfile or not GOOD_CONTENT_RE.search(ct):
      #replace attachment with some dummy string:
      mail.set_payload(" ")
      #get_params() returns None when the part has no Content-Type header;
      #treat that as "no parameters" instead of failing on None[1:]
      for k, v in (mail.get_params() or [])[1:]:
        mail.del_param(k)
      mail.set_type('text/plain')
      del mail['Content-Transfer-Encoding']
      del mail['Content-Disposition']
    return mail, stats

def remove_duplicates(l):
  """
  Return the distinct elements of l.

  The elements are used as dictionary keys, so they must be hashable; the
  result is whatever the dictionary's keys() returns, in no particular order.
  """
  return dict.fromkeys(l).keys()

def mutual_information(p12, p1, p2):
  """
  Mutual information of two binary variables from their probabilities.

  p12 -- the four joint probabilities, ordered (0,0), (0,1), (1,0), (1,1),
         where the first value belongs to the variable described by p1 and
         the second to the variable described by p2
  p1  -- the two marginal probabilities of the first variable
  p2  -- the two marginal probabilities of the second variable

  Returns the sum over the four cells of p*log(p) - p*log(q), with p the
  joint probability and q the product of the matching marginals. Every
  probability has to be positive, since its logarithm is taken.
  """
  mi = 0
  for cell, (a, b) in enumerate(((0, 0), (0, 1), (1, 0), (1, 1))):
    joint = p12[cell]
    mi = mi + joint*log(joint) - joint*log(p1[a]*p2[b])
  return mi

def addtodict(word, allwords):
  """
  Count one more occurrence of word in the dictionary allwords.

  word     -- the key to count
  allwords -- dictionary mapping words to counts; it is updated in place
              and a word not seen before starts at 1

  Returns None.
  """
  #dict.get replaces the deprecated has_key lookup
  allwords[word] = allwords.get(word, 0) + 1

def keys_by_increasing_values(d):
  """
  Return the keys of d as a list ordered by increasing value.

  Keys with equal values are ordered among themselves by the keys, because
  the sort is done on (value, key) pairs.
  """
  ranked = sorted([(value, key) for key, value in d.items()])
  return [key for value, key in ranked]

#-----------------------------------------------------

#Stop words are read from 'stopwords.txt', one per line with surrounding
#whitespace removed. They are kept in a set because membership is tested
#once for every word of every mail; the file is closed even if reading
#fails.
f = open('stopwords.txt', 'r')
try:
  stopwords = set(sw.strip() for sw in f)
finally:
  f.close()

#Scan both mailboxes and collect, for every accepted word,
#  spamwords / nospamwords -- how often the word was seen in each class
#  wordoccurs              -- the positions (indices into isspam and
#                             statistics) of the mails containing the word
#A word is accepted if it is shorter than 15 characters, contains no
#digit and is not a stop word.
spamwords = {}; nospamwords = {}; wordoccurs = {}
count = 0         #running mail number, only used for the progress output
statistics = []
for mb, spam in [(spambox,1), (nospambox,0)]:
  if spam: classwords = spamwords
  else: classwords = nospamwords
  while(1):
    try:
      mail = mb.next()
      #next() returns None at the end of the mailbox. Test for None
      #explicitly: the truth value of a message is its number of headers,
      #so a header-less message would otherwise end the scan early.
      if mail is None: break
      count = count+1
      print("Mail: %s  Spam: %s" % (count, spam))
      mail, stats = traverse(mail, {})
      body = mail.as_string()
      #NOTE(review): this deletion happens after the body text has been
      #produced and the message is not used again, so it looks like it has
      #no effect on the extracted features -- confirm the intent.
      del mail['Content-Type']
      #Register the mail only after everything that may raise has
      #succeeded, so that statistics and isspam always have one entry per
      #mail. mailindex is the 0-based position of this mail in both lists
      #and is what the later "i in wordoccurs[w]" lookup expects (the
      #1-based count shifted every word occurrence to the next mail).
      mailindex = len(isspam)
      statistics.append(stats)
      isspam.append(spam)
      for word in pdelim.split(body):
        word = word.lower()
        if len(word)<15 and not pnum.search(word) and word not in stopwords:
          #addtodict(word,allwords)
          addtodict(word,classwords)
          wordoccurs.setdefault(word, []).append(mailindex)
    except email.Errors.HeaderParseError:
      print('headerparseerror...skipped')

#Candidate words: the numwords most frequent words of each class, merged
#and with duplicates removed.
isspam = array(isspam)
nummails = array(len(isspam),type=Float32)
commonwords = keys_by_increasing_values(spamwords)[-numwords:]
commonwords.extend(keys_by_increasing_values(nospamwords)[-numwords:])
commonwords = remove_duplicates(commonwords)
wordoccurence = {}

#compute wordoccurence frequencies
#wordoccurence[w][i] is 1 if index i was recorded for word w in wordoccurs
#and 0 otherwise. The recorded indices are visited directly instead of
#searching the whole list once per mail; indices outside the range of
#mails are ignored, exactly as the per-mail search would ignore them.
print('computing word occurence frequencies...')
nmails = int(nummails)
for w in commonwords:
  print(w)
  occurs = zeros(nmails)
  for i in wordoccurs[w]:
    if 0 <= i < nmails:
      occurs[i] = 1
  wordoccurence[w] = occurs
print('done...')

#compute probabilities using '+1'-smoothing
#Each of the four (class, word present) cells gets one extra count, hence
#the +4 in the denominator and the +2 for the marginals.
print('computing probabilities...')
denominator = nummails+4
pspam = (sum(array(isspam,type=Float32))+2)/denominator
pspam = array([1-pspam, pspam])
pwords = {}; pjoints = {}
for w in commonwords:
  occurs = wordoccurence[w]
  pword = (sum(occurs) + 2)/denominator
  pwords[w] = array([1-pword, pword])
  #joint probabilities ordered (no spam, absent), (no spam, present),
  #(spam, absent), (spam, present)
  pjoints[w] = [(sum(where((isspam==label)&(occurs==present),1.0,0.0))+1)/denominator
                for label in (0, 1) for present in (0, 1)]
print('done...')

#compute the mutual information
#mi[k] is the mutual information between the class label and the
#occurrence of the word commonwords[k]
print('computing mutual information...')
mi = [mutual_information(pjoints[w],pspam,pwords[w]) for w in commonwords]
print('...done')

#extract features
#keep the numwordfeatures words with the largest mutual information,
#listed by increasing mutual information
words = []
bestindices = argsort(array(mi))[-numwordfeatures:]
words_sorted = [commonwords[ii] for ii in bestindices]
words_and_mi = [commonwords[ii] + ' ' + str(mi[ii]) for ii in bestindices]

#construct feature vectors:
#features has one column per mail; the first len(words_sorted) rows are the
#word features, followed by 7 rows of non-word features
print('extracting features... ')
numwordrows = len(words_sorted)
features = zeros((numwordrows+7,len(isspam)))
for ii in range(len(isspam)):
  #word features: 1 if the word occurs in mail ii; the other entries keep
  #the 0 they were initialised with
  for jj, w in enumerate(words_sorted):
    if wordoccurence[w][ii]==1:
      features[jj,ii] = 1
  #extract some extra (nonword-)features
  mailstats = statistics[ii]
  numctypes = len(mailstats['ctypes'])
  features[numwordrows,ii]   = int(mailstats['containsfiles'])
  features[numwordrows+1,ii] = int(numctypes==1)
  features[numwordrows+2,ii] = int(numctypes==2)
  features[numwordrows+3,ii] = int(numctypes==3)
  features[numwordrows+4,ii] = int(numctypes>3)
  features[numwordrows+5,ii] = int(mailstats['nosubject'])
  features[numwordrows+6,ii] = int(mailstats['threecaps'])
print('...done')

#dump data:
#the feature matrix and the class labels go to 'data' via saveloaddata;
#the selected words with their mutual information go to 'words.txt',
#one "word value" pair per line
DATA = {'features':features, 'class':isspam}
saveloaddata.dumpData(DATA, 'data')
f = open('words.txt', 'w')
try:
  f.write('\n'.join(words_and_mi))
finally:
  #close the file even if writing fails
  f.close()
print('Done.')

