# SD_learner_classifier.py
# Subgroup-discovery learner and rule-based classifier for Orange.
import sys
import orange
from SDRule import *
from Beam_SD import *
from Beam_SD_preprocessed import *
from Apriori_SD import *
from CN2_SD import *
from xmlMaker import *
import cStringIO



class SD_learner(orange.Learner):
    """Orange learner that induces subgroup-discovery rules for every class value.

    The rule inducer is chosen with ``algorithm``:

    * ``"SD"``            -- ``Beam_SD(minSupport, beamWidth, g)``
    * ``"SD-Preprocess"`` -- ``Beam_SD_preprocessed(minSupport, beamWidth, g)``
    * ``"Apriori-SD"``    -- ``Apriori_SD(minSupport, minConfidence, k)``
    * ``"CN2-SD"``        -- ``CN2_SD(k)``

    Calling the learner on an ``orange.ExampleTable`` runs the inducer once per
    value of the class variable and returns an ``SD_Classifier`` holding the
    resulting rules.
    """
    # Default constructor arguments.
    def_name = 'SD classifier'
    def_alg = 'SD'
    def_minSupport = 0.1
    def_minConfidence = 0.7
    def_beamWidth = 5
    def_g = 1.0
    def_k = 4
    def_maxRules = 0

    def __new__(cls, examples=None, **kwds):
        """Create the learner, or, when *examples* is given, train and return a classifier.

        ``SD_learner(**kwds)`` returns a learner; ``SD_learner(data, **kwds)``
        returns the ``SD_Classifier`` obtained by training on ``data``.
        """
        learner = orange.Learner.__new__(cls, **kwds)
        if examples:
            # Python only calls __init__ when __new__ returns an instance of cls;
            # here a classifier is returned, so the learner must be initialised
            # explicitly before it is used for training.
            learner.__init__(**kwds)
            return learner(examples)
        return learner

    def __init__(self, name = def_name, algorithm = def_alg, minSupport = def_minSupport, \
                 minConfidence = def_minConfidence, beamWidth = def_beamWidth, g = def_g, \
                 k = def_k, max_rules = def_maxRules):
        """Validate the parameters and build the selected rule inducer.

        :param name: name copied onto the produced classifier.
        :param algorithm: one of "SD", "SD-Preprocess", "Apriori-SD", "CN2-SD".
        :param minSupport: float in (0, 1]; used by SD, SD-Preprocess and Apriori-SD.
        :param minConfidence: float in (0, 1]; used by Apriori-SD.
        :param beamWidth: int in (0, 1000]; used by SD and SD-Preprocess.
        :param g: non-negative int or float; used by SD and SD-Preprocess.
        :param k: non-negative int; used by Apriori-SD and CN2-SD.
        :param max_rules: non-negative int passed to the inducer on every call.
        :raises Exception: for an unknown ``algorithm``.
        :raises ValueError: when a numeric parameter has the wrong type or range.
        """
        # parameter checking
        if algorithm not in ["SD","SD-Preprocess","Apriori-SD","CN2-SD"]:
            raise Exception('unknown algorithm %s.' % algorithm)
        if type(minSupport) is not float or (minSupport > 1.0) or (minSupport <= 0.0):
            raise  ValueError('minSupport should be a float in the (0,1] range.')
        if type(minConfidence) is not float or (minConfidence > 1.0) or (minConfidence <= 0.0):
            raise  ValueError('minConfidence should be a float in the (0,1] range.')
        if type(beamWidth) is not int or (beamWidth <= 0) or (beamWidth > 1000):
            raise  ValueError('beamWidth should be an int in the (0,1000] range.')
        if type(g) not in [int, float] or (g < 0):
            raise  ValueError('g should be a non-negative int or float.')
        if type(k) is not int or (k < 0):
            raise  ValueError('k should be a non-negative int.')
        if type(max_rules) is not int or (max_rules < 0):
            raise  ValueError('max_rules should be a non-negative int.')

        self.name = name
        self.max_rules = max_rules
        self.algorithm = algorithm
        if algorithm == "SD":
            self.learner = Beam_SD(  minSupport , beamWidth , g )
        elif algorithm == "SD-Preprocess":
            self.learner = Beam_SD_preprocessed(  minSupport , beamWidth , g )
        elif algorithm == "Apriori-SD":
            self.learner = Apriori_SD(minSupport , minConfidence , k)
        elif algorithm == "CN2-SD":
            self.learner = CN2_SD( k )
        else:
            # Defensive only: the algorithm name was validated above.
            raise Exception('No such algorithm %s' % algorithm)

    def __call__(self, learndata, testdata = None, weight = None):
        """Induce rules for every class value of *learndata* and return an SD_Classifier.

        :param learndata: ``orange.ExampleTable`` the rules are induced from.
        :param testdata: optional table; when given, the classifier (and its
            majority-class fallback) is built on it instead of on *learndata*
            (original note: "because of preprocessing").
        :param weight: accepted for interface compatibility; not used here.
        """
        # because of preprocessing
        if testdata:
            classifier = SD_Classifier(testdata)
        else:
            classifier = SD_Classifier(learndata)

        classVar = learndata.domain.classVar
        for targetClassValue in classVar.values:
            targetClass = orange.Value(classVar, targetClassValue)
            # The inducer returns the rules found for this target class.
            beam = self.learner(learndata, targetClass, self.max_rules)
            classifier.addRulesForClass(beam, targetClass)

        classifier.name = self.name
        classifier.algorithm = self.algorithm
        return classifier

#________________________________________________________________________________________

class SD_Classifier(orange.Classifier):
85
86
87
88
89
90
    def __init__(self, data=None):
        if data:
            if type(data) is not orange.ExampleTable:
                raise TypeError('Data is not an orange.ExampleTable')
            if data.domain.classVar.varType != orange.VarTypes.Discrete:
                raise TypeError('Data should have a discrete target variable.')
91
92

        self.data = data
93
        self.majorityClassifier = data and orange.MajorityLearner(self.data)
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
        self.rulesClass = []             # list of istances SDRules
        self.algorithm = "Subgroup discovery algorithm"

    def __call__(self, example, resultType = orange.GetValue):
        # 1. calculate sum of distributions of examples that cover the example
        num_cover = 0.0
        distribution = [0]* len(self.data.domain.classVar.values)
        for rsc in self.rulesClass:
            for rule in rsc.rules.rules:
                if rule.covers(example):
                    num_cover += 1
                    tmp_dist = rule(example, orange.GetProbabilities)
                    for i in range(len(distribution)):
                        distribution[i] += tmp_dist[i]
        # 2. calculate average of distributions of rules that cover example
        if num_cover != 0:
            max_index = 0
            for i in range(len(distribution)):
                distribution[i] = distribution[i]/num_cover
                if distribution[i] > distribution[max_index]:
                    max_index = i
            dist = orange.DiscDistribution(distribution)
            value = orange.Value(self.data.domain.classVar ,self.data.domain.classVar.values[max_index])
        # if no rule fiers
        else:
            value,dist = self.majorityClassifier(example, orange.GetBoth)

        # 3. -----------return
        if resultType == orange.GetValue :
            return value
        elif resultType == orange.GetBoth :
            return (value, dist)
        else :
            return dist

    def addRulesForClass(self, listOfRules, targetClass):
        targetClassRule = SDRule(self.data, targetClass, conditions=[], g =1)
        tmp = SDRules(listOfRules, targetClassRule )
        self.rulesClass.append( tmp)

    def printAll(self):
        rulesList = []
        for rClass in self.rulesClass:
            for rule in rClass.rules.rules:
                rulesList.append(rule.printRule())

        strobj = cStringIO.StringIO()
        strobj.writelines(rulesList)
        rules = strobj.getvalue()
        strobj.close()
        return rules


    def getRules(self, targetClass):
        for rsc in self.rulesClass:
            if rsc.targetClassRule.targetClass == targetClass:
                return rsc.rules

    # The parameter must be a writable object type that supports at least write(string) function
    # e.g. file, StringIO.StringIO, or a custom object type with the write(string) function
    # The resulting object must be closed manually by calling close()
    #
    def toPMML(self, outputObjType=cStringIO.StringIO):
        ''' Ouutput the ruleset in a PMML RuleSet XML schema. '''
        output = outputObjType
        myXML = XMLCreator()
        myXML.DOMTreeTop = dom.DOMImplementation().createDocument('', "PMML", None)
        myXML.DOMTreeRoot = myXML.DOMTreeTop.documentElement
        myXML.DOMTreeRoot.setAttribute("version", "3.2")
        myXML.DOMTreeRoot.setAttribute("xmlns", "http://www.dmg.org/PMML-3_2")
        myXML.DOMTreeRoot.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")

#Header
        header = myXML.insertNewNode(myXML.DOMTreeRoot, "Header")
        header.setAttribute("copyright", "MyCopyright")

        application = myXML.insertNewNode(header, "Application")
        application.setAttribute("name", "Subgroup discovery toolbox for Orange")
        application.setAttribute("version", "1.0")

#DataDictionary
        dataDictionary = myXML.insertNewNode(myXML.DOMTreeRoot, "DataDictionary")
        dataDictionary.setAttribute("numberOfFields", "%d"%len(self.data.domain.variables))
        for var in self.data.domain.variables:
            dataField = myXML.insertNewNode(dataDictionary, "DataField")
            dataField.setAttribute("name", var.name)
            dataField.setAttribute("displayName", var.name)
            dataField.setAttribute("optype", ["","categorical","continuous","other"][var.varType])
            dataField.setAttribute("dataType", ["","string","double",""][var.varType])
            if var.varType==1:
                for val in var.values:
                    value = myXML.insertNewNode(dataField, "Value")
                    value.setAttribute("value", val)
                    value.setAttribute("property", "valid")

#RuleSetModel
        ruleSetModel = myXML.insertNewNode(myXML.DOMTreeRoot, "RuleSetModel")
        ruleSetModel.setAttribute("modelName", "SubgroupDiscoveryRules") # spremeni v dinamicno
        ruleSetModel.setAttribute("functionName", "classification")
        ruleSetModel.setAttribute("algorithmName", self.algorithm)
  # MiningSchema
        miningSchema = myXML.insertNewNode(ruleSetModel, "MiningSchema")
        miningField = myXML.insertNewNode(miningSchema, "MiningField")
        miningField.setAttribute("name", self.data.domain.classVar.name)  # the target variable
        miningField.setAttribute("usageType", "predicted")
        for attr in self.data.domain.attributes:                          # the attributes
            miningField = myXML.insertNewNode(miningSchema, "MiningField")
            miningField.setAttribute("name", attr.name)
            miningField.setAttribute("usageType", "active")

  #RuleSet
        # default rule
        defVal = self.majorityClassifier.defaultValue.value
        defconf = self.majorityClassifier.defaultDistribution [self.majorityClassifier.defaultVal]
        defnbCorrect = defconf * len(self.data)
        ruleSet = myXML.insertNewNode(ruleSetModel, "RuleSet")
        ruleSet.setAttribute("defaultScore", defVal)
        ruleSet.setAttribute("recordCount", "%d"%len(self.data))
        ruleSet.setAttribute("nbCorrect", "%0.0f"%defnbCorrect)
        ruleSet.setAttribute("defaultConfidence", "%0.2f"%defconf)

        ruleSelectionMethod = myXML.insertNewNode(ruleSet, "RuleSelectionMethod")
        ruleSelectionMethod.setAttribute("criterion", "weightedSum")
        #rules
        for rsc in self.rulesClass:                          # for each class
            for i in range(len(rsc.rules.rules)):            # for each rule
                rule = rsc.rules.rules[i]
                simpleRule = myXML.insertNewNode(ruleSet, "SimpleRule")
                simpleRule.setAttribute("id", "%s%d"%(rsc.targetClassRule.targetClass,i+1))
                simpleRule.setAttribute("score", rule.targetClass.value)
                simpleRule.setAttribute("recordCount", "%d"%len(rule.examples))
                simpleRule.setAttribute("nbCorrect", "%d"%len(rule.TP))
                simpleRule.setAttribute("confidence", "%2.2f"%rule.confidence)
                simpleRule.setAttribute("weight", "1")
                if len (rule.filter.conditions )>1  :
                    compoundPredicate = myXML.insertNewNode(simpleRule, "CompoundPredicate")
                    compoundPredicate.setAttribute("booleanOperator", "and")
                elif len (rule.filter.conditions )==1 and \
                     (rule.data.domain[rule.filter.conditions[0].position].varType == orange.VarTypes.Continuous) and \
                      rule.filter.conditions[0].min != float(-infinity):
                    #if there is only one continuous condition in the filter with a range interval
                    compoundPredicate = myXML.insertNewNode(simpleRule, "CompoundPredicate")
                    compoundPredicate.setAttribute("booleanOperator", "and")
                else:
                    compoundPredicate = simpleRule
                for i,c in enumerate(rule.filter.conditions):
                    simplePredicate = myXML.insertNewNode(compoundPredicate, "SimplePredicate")
                    simplePredicate.setAttribute("field", rule.data.domain[c.position].name)
                    if rule.data.domain[c.position].varType == orange.VarTypes.Discrete:
                        simplePredicate.setAttribute("operator", "equal")
                        simplePredicate.setAttribute("value",str(rule.data.domain[c.position].values[int(c.values[0])]))

                    elif rule.data.domain[c.position].varType == orange.VarTypes.Continuous:
                        if c.min == float(-infinity):
                            if not c.outside:
                                simplePredicate.setAttribute("operator", "lessOrEqual") # <=
                            else:
                                simplePredicate.setAttribute("operator", "greaterThan") # >
                            simplePredicate.setAttribute("value","%.3f" % c.max)
                        else:    #interval gets transformed into two simple predicates, one <= and one >
                            simplePredicate.setAttribute("operator", "greaterThan") #this causes problems if there is only one condition with an interval in one rule, sice the "compoundRule" is not present
                            simplePredicate.setAttribute("value","%.3f" % c.min)
                            simplePredicate = myXML.insertNewNode(compoundPredicate, "SimplePredicate")
                            simplePredicate.setAttribute("field", rule.data.domain[c.position].name)
                            simplePredicate.setAttribute("operator", "lessOrEqual")
                            simplePredicate.setAttribute("value","%.3f" % c.max)
                for i in range(len(rule.targetClass.variable.values)):
                    scoreDistribution = myXML.insertNewNode(simpleRule, "ScoreDistribution")
                    scoreDistribution.setAttribute("value", rule.targetClass.variable.values[i])
                    scoreDistribution.setAttribute("recordCount", "%0.0f"%(rule.classDistribution[i]*len(rule.examples)))

        try:
            #PrettyPrint(myXML.DOMTreeTop, output)
            output.write(myXML.DOMTreeTop.toprettyxml())
            return output
        except Exception, e:
            print 'Error while outputting rules in PMML: ' + str(e)
            return None

    #end toPMML()


#___________________________________________________________________________________
if __name__=="__main__":

279
280
281
282
    # filename = "..\\..\\doc\\datasets\\lenses.tab"
    # if 'linux' in sys.platform:
    #     filename= "/usr/doc/orange/datasets/lenses.tab"
    data = orange.ExampleTable('lenses')
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301


    learner2 = SD_learner(algorithm = "Apriori-SD", minSupport = 0.1, minConfidence= 0.6)
    classifier2 = learner2(data)
    classifier2.printAll()

    learner3 = SD_learner(algorithm = "CN2-SD", k=3)
    classifier3 = learner2(data)
    classifier3.printAll()

    learner4 = SD_learner(algorithm = "SD-Preprocess", minSupport = 0.1, beamWidth = 5, g = 1)
    classifier4 = learner4(data)
    classifier4.printAll()

    print "___________________________"
    for d in data:
        print d.getclass(), classifier2(d, orange.GetValue), classifier3(d, orange.GetValue), classifier4(d, orange.GetValue)


302
303
304
305
306
307
308
309
310
    import cPickle
    one= classifier2.rulesClass[0].rules.rules[0]
    
    for obj in dir(one):
        try:
            cPickle.dump(getattr(one, obj), open('foo.pkl','w'))
            print obj, 'ok'
        except Exception, e:
            print obj, str(e)
311
312
313
314
315
316
317
318
319
320
321
322
323


    print "\n\n---> PMML model <---"
    result = classifier2.toPMML();
    if isinstance(result, file):
        print 'Result in file ', result.name
    elif isinstance(result, cStringIO.OutputType):
        print result.getvalue()
    else:
        raise TypeError
    result.close()