'''
Classes for handling DBContexts for ILP systems.

@author: Anze Vavpetic <anze.vavpetic@ijs.si>
'''
6
import re
7

8
class Converter(object):
    '''
    Base class for converters.

    Opens a connection through the given DB context and keeps one cursor on
    it; both are available to subclasses as `self.connection` and
    `self.cursor`.
    '''
    def __init__(self, dbcontext):
        '''
        @param dbcontext: the database context to convert. It is stored as
                          `self.db`; its `connection.connect()` is called to
                          obtain the connection used by this converter.
        '''
        self.db = dbcontext
        self.connection = dbcontext.connection.connect()
        self.cursor = self.connection.cursor()

    def __del__(self):
        '''
        Closes the connection opened by the constructor.
        '''
        # __del__ also runs for objects whose __init__ raised before
        # `self.connection` was assigned (e.g. connect() failed); without the
        # guard that case ends in an AttributeError during finalisation.
        connection = getattr(self, 'connection', None)
        if connection is not None:
            connection.close()

20
21
22
23
24
25
26
27
class ILP_Converter(Converter):
    '''
    Base class for converting between a given database context (selected tables, columns, etc)
    to inputs acceptable by a specific ILP system.

    If possible, all subclasses should use lazy selects by forwarding the DB connection.
    '''
    def __init__(self, *args, **kwargs):
        '''
        Optional keyword arguments consumed here (everything else is
        forwarded to Converter.__init__):

        @param settings: dict of settings, emitted by `user_settings` as
                         `:- set(key,val).` directives.
        @param discr_intervals: dict of the form
                         {table: {attribute: [threshold, ...]}} used by
                         `attribute_clause` to discretize attributes.
        @param dump: flag stored as `self.dump`; the subclasses in this file
                     use it to choose between dumping the rows as facts and
                     emitting DB connection directives.
        '''
        self.settings = kwargs.pop('settings', {})
        self.discr_intervals = kwargs.pop('discr_intervals', {})
        self.dump = kwargs.pop('dump', False)
        Converter.__init__(self, *args, **kwargs)

    def user_settings(self):
        '''
        Returns one `:- set(key,val).` directive per entry of `self.settings`.
        '''
        return [':- set(%s,%s).' % (key, val) for key, val in self.settings.items()]

    def mode(self, predicate, args, recall=1, head=False):
        '''
        Returns a mode declaration for the given predicate.

        @param predicate: predicate name.
        @param args: list of (mode symbol, type) string pairs, e.g. ('+', 'trains').
        @param recall: recall value placed in the declaration (converted with str).
        @param head: if true a `modeh` declaration is produced, otherwise `modeb`.
        '''
        kind = 'h' if head else 'b'
        typed_args = ','.join([symbol + arg for symbol, arg in args])
        return ':- mode%s(%s, %s(%s)).' % (kind, str(recall), predicate, typed_args)

    def db_connection(self):
        '''
        Returns the directives that load the myddas library, open the
        database and import every table of the context under its own name.

        Note that the host, database name, user and password of the context's
        connection are written verbatim into the returned text.
        '''
        con = self.db.connection
        host, db, user, pwd = con.host, con.database, con.user, con.password
        header = [':- use_module(library(myddas)).',
                  ':- db_open(mysql, \'%s\'/\'%s\', \'%s\', \'%s\').' % (host, db, user, pwd)]
        imports = [':- db_import(%s, %s).' % (table, table) for table in self.db.tables]
        return header + imports

    def connecting_clause(self, table, ref_table):
        '''
        Returns the lines of the `has_<ref_table>/2` clauses linking `table`
        to `ref_table`, one clause per column pair stored in
        `self.db.connected[(table, ref_table)]`.

        Every column becomes a variable named after the capitalized column;
        the linking columns are renamed to the capitalized table names so
        that the two body goals share variables.
        '''
        var_table, var_ref_table = table.capitalize(), ref_table.capitalize()
        result = []
        for pk, fk in self.db.connected[(table, ref_table)]:
            ref_pk = self.db.pkeys[ref_table]
            table_args, ref_table_args = [], []
            # NOTE(review): `col in fk` is a substring test if `fk` is a
            # plain column name (as the other uses of `db.connected` in this
            # file suggest) - confirm whether `col == fk` was intended.
            for col in self.db.cols[table]:
                if col == pk:
                    col = var_table
                elif col in fk:
                    col = var_ref_table
                table_args.append(col.capitalize())
            for col in self.db.cols[ref_table]:
                if col == ref_pk:
                    col = var_ref_table
                if col in fk:
                    col = var_table
                ref_table_args.append(col.capitalize())
            result.extend(['has_%s(%s, %s) :-' % (ref_table, var_table, var_ref_table),
                           '\t%s(%s),' % (table, ','.join(table_args)),
                           '\t%s(%s).' % (ref_table, ','.join(ref_table_args))])
        return result

    def attribute_clause(self, table, att):
        '''
        Returns the lines of the `<table>_<att>/2` clause that relates an
        entity of `table` to its value of attribute `att`.

        If `self.discr_intervals[table][att]` holds a non-empty list of
        thresholds t_0..t_k, the second argument is instead bound to a label:
        '=< t_0', '(t_i, t_i+1]' or '> t_k', expressed as a disjunction.
        '''
        var_table, pk = table.capitalize(), self.db.pkeys[table]
        raw_att = att.capitalize()
        var_att = raw_att

        intervals = []
        if table in self.discr_intervals:
            intervals = self.discr_intervals[table].get(att, [])
            if intervals:
                var_att = 'Discrete_%s' % var_att

        table_args = ','.join([var_table if arg == pk else arg.capitalize()
                               for arg in self.db.cols[table]])
        values_goal = '\t%s(%s)%s' % (table, table_args, ',' if intervals else '.')

        discretize_goals = []
        last = len(intervals) - 1
        for i, value in enumerate(intervals):
            if i == 0:
                # Condition: att =< value_0. This alternative opens the
                # disjunction and is always followed by another one (at
                # least 'att > value_0'), so it must end with ';' even when
                # there is a single threshold.
                label = '=< %.2f' % value
                condition = '%s =< %.2f' % (raw_att, value)
                discretize_goals.append('\t((%s = \'%s\', %s);' % (var_att, label, condition))
            if i < last:
                # Condition: att in (value_i, value_i+1]
                value_next = intervals[i + 1]
                label = '(%.2f, %.2f]' % (value, value_next)
                condition = '%s > %.2f, %s =< %.2f' % (raw_att, value, raw_att, value_next)
                discretize_goals.append('\t(%s = \'%s\', %s);' % (var_att, label, condition))
            else:
                # Condition: att > value_i; closes the disjunction and the clause.
                label = '> %.2f' % value
                condition = '%s > %.2f' % (raw_att, value)
                discretize_goals.append('\t(%s = \'%s\', %s)).' % (var_att, label, condition))

        return ['%s_%s(%s, %s) :-' % (table, att, var_table, var_att),
                values_goal] + discretize_goals

    @staticmethod
    def numeric(val):
        '''
        Returns True if `val` can be converted by int, float or complex.
        '''
        # `long` is not listed: on Python 2 int() already accepts (and
        # returns) long values, and the name does not exist on Python 3.
        for num_type in (int, float, complex):
            try:
                num_type(val)
            except (TypeError, ValueError, OverflowError):
                continue
            return True
        return False

    def dump_tables(self):
        '''
        Returns a list with one string per table of the context; each string
        holds one `table(v1,...,vn).` fact per row, separated by newlines.
        Values accepted by `numeric` are written bare, all others are wrapped
        in single quotes.
        '''
        def fmt_cols(values):
            return ','.join([("%s" % val) if ILP_Converter.numeric(val) else ("'%s'" % val)
                             for val in values])

        dump = []
        for table in self.db.tables:
            attributes = self.db.cols[table]
            facts = ["%s(%s)." % (table, fmt_cols(row))
                     for row in self.db.rows(table, attributes)]
            dump.append('\n'.join(facts))
        return dump

120
class RSD_Converter(ILP_Converter):
    '''
    Produces the RSD inputs (examples and background knowledge) for the
    database context.
    '''
    def all_examples(self):
        '''
        Returns all rows of the target table as newline-separated
        `target('class', pk).` facts.
        '''
        target_table = self.db.target_table
        wanted = [self.db.target_att, self.db.pkeys[target_table]]
        facts = []
        for cls, pk in self.db.rows(target_table, wanted):
            facts.append("%s('%s', %s)." % (target_table, cls, pk))
        return '\n'.join(facts)

    def background_knowledge(self):
        '''
        Returns the background knowledge as one newline-joined string: mode
        declarations, connecting/attribute clauses and user settings,
        preceded by the DB connection directives or, when `self.dump` is
        set, followed by the dumped tables instead.
        '''
        db = self.db
        modes = [self.mode(db.target_table, [('+', db.target_table)], head=True)]
        clauses = []

        for (table, ref_table) in db.connected.keys():
            # Connections pointing back at the target table are not followed.
            if ref_table != db.target_table:
                modes.append(self.mode('has_%s' % ref_table, [('+', table), ('-', ref_table)]))
                clauses.extend(self.connecting_clause(table, ref_table))

        for table, atts in db.cols.items():
            for att in atts:
                is_class_att = att == db.target_att and table == db.target_table
                # The class attribute and key columns get no attribute predicate.
                if is_class_att or att in db.fkeys[table] or att == db.pkeys[table]:
                    continue
                modes.append(self.mode('%s_%s' % (table, att), [('+', table), ('-', att)]))
                modes.append(self.mode('instantiate', [('+', att)]))
                clauses.extend(self.attribute_clause(table, att))

        if self.dump:
            parts = modes + clauses + self.user_settings() + self.dump_tables()
        else:
            parts = self.db_connection() + modes + clauses + self.user_settings()
        return '\n'.join(parts)
149

150
class Aleph_Converter(ILP_Converter):
    '''
    Converts the database context to Aleph inputs.
    '''
    def __init__(self, *args, **kwargs):
        '''
        @param target_att_val: required keyword argument; rows whose target
                               attribute equals this value (compared as
                               strings) are the positive examples.

        All other arguments are forwarded to ILP_Converter.__init__.
        '''
        self.target_att_val = kwargs.pop('target_att_val')
        ILP_Converter.__init__(self, *args, **kwargs)
        # Lazily filled by __examples().
        self.__pos_examples, self.__neg_examples = None, None
        # Lower-cased target value with every whitespace run replaced by '_'.
        self.target_predicate = re.sub(r'\s+', '_', self.target_att_val).lower()

    def __target_predicate(self):
        '''
        Returns the name of the predicate to be learned.
        '''
        return 'target_%s' % self.target_predicate

    def __examples(self):
        '''
        Returns the pair (positive examples, negative examples), each a
        newline-joined string of `target_<value>(pk).` facts. The rows are
        fetched on the first call only.

        Raises Exception if no row has the requested target value.
        '''
        # Compare against None: an empty negative set is the valid cached
        # value '' and must not trigger a new query on every call.
        if self.__pos_examples is None or self.__neg_examples is None:
            target, att, target_val = self.db.target_table, self.db.target_att, self.target_att_val
            rows = self.db.rows(target, [att, self.db.pkeys[target]])
            pos_rows, neg_rows = [], []
            for row in rows:
                if str(row[0]) == target_val:
                    pos_rows.append(row)
                else:
                    neg_rows.append(row)

            if not pos_rows:
                raise Exception('No positive examples with the given target attribute value, please re-check.')

            predicate = self.__target_predicate()
            self.__pos_examples = '\n'.join(['%s(%s).' % (predicate, pk) for _, pk in pos_rows])
            self.__neg_examples = '\n'.join(['%s(%s).' % (predicate, pk) for _, pk in neg_rows])
        return self.__pos_examples, self.__neg_examples

    def positive_examples(self):
        '''
        Returns the positive examples as newline-separated facts.
        '''
        return self.__examples()[0]

    def negative_examples(self):
        '''
        Returns the negative examples as newline-separated facts.
        '''
        return self.__examples()[1]

    def background_knowledge(self):
        '''
        Returns the background knowledge as one newline-joined string: user
        settings, mode declarations, determinations, type definitions and
        connecting/attribute clauses. Without `self.dump` these are preceded
        by the DB connection directives and the local-copy directives; with
        `self.dump` they are followed by the dumped tables instead.
        '''
        target_predicate = self.__target_predicate()
        modeslist, getters = [self.mode(target_predicate, [('+', self.db.target_table)], head=True)], []
        determinations, types = [], []
        for (table, ref_table) in self.db.connected.keys():
            if ref_table == self.db.target_table:
                continue # Skip backward connections
            modeslist.append(self.mode('has_%s' % ref_table, [('+', table), ('-', ref_table)], recall='*'))
            determinations.append(':- determination(%s/1, has_%s/2).' % (target_predicate, ref_table))
            # Type definitions are added once per connection, so a table
            # taking part in several connections gets its definition repeated.
            types.extend(self.concept_type_def(table))
            types.extend(self.concept_type_def(ref_table))
            getters.extend(self.connecting_clause(table, ref_table))
        for table, atts in self.db.cols.items():
            for att in atts:
                # The class attribute and key columns get no attribute predicate.
                if att == self.db.target_att and table == self.db.target_table or \
                   att in self.db.fkeys[table] or att == self.db.pkeys[table]:
                    continue
                modeslist.append(self.mode('%s_%s' % (table, att), [('+', table), ('#', att)], recall='*'))
                determinations.append(':- determination(%s/1, %s_%s/2).' % (target_predicate, table, att))
                types.extend(self.constant_type_def(table, att))
                getters.extend(self.attribute_clause(table, att))
        local_copies = [self.local_copy(table) for table in self.db.tables]
        if not self.dump:
            b = '\n'.join(self.db_connection() + local_copies + self.user_settings() + modeslist + determinations + types + getters)
        else:
            b = '\n'.join(self.user_settings() + modeslist + determinations + types + getters + self.dump_tables())
        return b

    def concept_type_def(self, table):
        '''
        Returns the lines of the unary clause `table(Pk) :- table(...)`, whose
        body names only the primary key column and leaves the rest anonymous.
        '''
        var_pk = self.db.pkeys[table].capitalize()
        variables = ','.join([var_pk if col.capitalize() == var_pk else '_' for col in self.db.cols[table]])
        return ['%s(%s) :-' % (table, var_pk),
                '\t%s(%s).' % (table, variables)]

    def constant_type_def(self, table, att):
        '''
        Returns the lines of the unary clause `att(Att) :- table(...)`, whose
        body names only column `att` and leaves the rest anonymous.
        '''
        var_att = att.capitalize()
        variables = ','.join([var_att if col == att else '_' for col in self.db.cols[table]])
        return ['%s(%s) :-' % (att, var_att),
                '\t%s(%s).' % (table, variables)]

    def db_connection(self):
        '''
        Same as ILP_Converter.db_connection, except that every table is
        imported under the name `tmp_<table>` (see `local_copy`).

        Note that the host, database name, user and password of the context's
        connection are written verbatim into the returned text.
        '''
        con = self.db.connection
        host, db, user, pwd = con.host, con.database, con.user, con.password
        return [':- use_module(library(myddas)).', \
                ':- db_open(mysql, \'%s\'/\'%s\', \'%s\', \'%s\').' % (host, db, user, pwd)] + \
               [':- db_import(%s, tmp_%s).' % (table, table) for table in self.db.tables]

    def local_copy(self, table):
        '''
        Returns a directive that asserts every `tmp_<table>` tuple as a
        `<table>` fact, skipping tuples that are already present.
        '''
        cols = ','.join([col.capitalize() for col in self.db.cols[table]])
        return ':- repeat, tmp_%s(%s), (%s(%s), !, fail ; assertz(%s(%s)), fail).' % (table, cols, table, cols, table, cols)
236

237
238
239

class Orange_Converter(Converter):
    '''
    Converts the selected tables in the given context to orange example tables.
    '''
    # MySQL type names, grouped by how orng_type treats them.
    continuous_types = ('FLOAT','DOUBLE','DECIMAL','NEWDECIMAL')
    integer_types = ('TINY','SHORT','LONG','LONGLONG','INT24')
    ordinal_types = ('YEAR','VARCHAR','SET','VAR_STRING','STRING','BIT')

    def __init__(self, *args, **kwargs):
        '''
        Besides the base class set-up, fetches the column types of every
        table of the context into `self.types` and asks the context to
        compute its column values.
        '''
        Converter.__init__(self, *args, **kwargs)
        self.types = {}
        for table in self.db.tables:
            self.types[table] = self.db.fetch_types(table, self.db.cols[table])
        self.db.compute_col_vals()

    def target_Orange_table(self):
        '''
        Returns the target table as an orange example table, with the target
        attribute as its class variable.
        '''
        table, cls_att = self.db.target_table, self.db.target_att
        return self.convert_table(table, cls_att)

    def other_Orange_tables(self):
        '''
        Returns a list of orange example tables (without class variable),
        one for every table of the context except the target table.
        '''
        target_table = self.db.target_table
        return [self.convert_table(table, None) for table in self.db.tables if table != target_table]

    def convert_table(self, table_name, cls_att=None):
        '''
        Returns the given table as an orange example table.

        @param table_name: name of the table to convert.
        @param cls_att: column to use as the class variable, or None for a
                        table without a class variable.

        String-typed columns and primary/foreign key columns become meta
        attributes; NULL values are stored as '?'.

        Raises Exception if `cls_att` has the 'string' type.
        '''
        import orange

        cols = self.db.cols[table_name]
        attributes, metas, class_var = [], [], None
        for col in cols:
            att_type = self.orng_type(table_name, col)
            if att_type == 'd':
                att_vals = self.db.col_vals[table_name][col]
                att_var = orange.EnumVariable(str(col), values=[str(val) for val in att_vals])
            elif att_type == 'c':
                att_var = orange.FloatVariable(str(col))
            else:
                att_var = orange.StringVariable(str(col))
            if col == cls_att:
                if att_type == 'string':
                    raise Exception('Unsuitable data type for a target variable: %s' % att_type)
                class_var = att_var
                continue
            # The primary key is compared with ==, as everywhere else in this
            # file: `db.pkeys[table]` is a single column name, so `in` would
            # also match any column whose name is a substring of it.
            elif att_type == 'string' or \
                 table_name in self.db.pkeys and col == self.db.pkeys[table_name] or \
                 table_name in self.db.fkeys and col in self.db.fkeys[table_name]:
                metas.append(att_var)
            else:
                attributes.append(att_var)
        domain = orange.Domain(attributes, class_var)
        for meta in metas:
            domain.addmeta(orange.newmetaid(), meta)
        dataset = orange.ExampleTable(domain)
        dataset.name = table_name
        for row in self.db.rows(table_name, cols):
            example = orange.Example(domain)
            for col, val in zip(cols, row):
                example[str(col)] = str(val) if val is not None else '?'
            dataset.append(example)
        return dataset

    def orng_type(self, table_name, col):
        '''
        Assigns a given mysql column an orange type: 'c' (continuous),
        'd' (discrete) or 'string'.

        Integer columns count as continuous once they have 50 or more
        entries in `self.db.col_vals`, and as discrete otherwise.
        '''
        mysql_type = self.types[table_name][col]
        n_vals = len(self.db.col_vals[table_name][col])
        if mysql_type in Orange_Converter.continuous_types or (n_vals >= 50 and mysql_type in Orange_Converter.integer_types):
            return 'c'
        elif mysql_type in Orange_Converter.ordinal_types+Orange_Converter.integer_types:
            return 'd'
        else:
            return 'string'

314
315
316
317
318
319
320

class TreeLikerConverter(Converter):
    '''
    Converts a db context to the TreeLiker dataset format.
    '''
    def __init__(self, *args, **kwargs):
        '''
        @param discr_intervals: optional keyword argument, a dict of the form
                                {table: {attribute: [threshold, ...]}} used to
                                replace attribute values by interval labels.

        All other arguments are forwarded to Converter.__init__.
        '''
        self.discr_intervals = kwargs.pop('discr_intervals', {})
        # Template literals and the predicates already present in it; both
        # are filled as a side effect of _facts().
        self._template = []
        self._predicates = set()
        Converter.__init__(self, *args, **kwargs)

    def _row_pk(self, target, cols, row):
        '''
        Returns the value in `row` that belongs to the primary key column of
        `target`, or None if that column is not among `cols`.
        '''
        pk_att = self.db.pkeys[target]
        for col_name, value in zip(cols, row):
            if col_name == pk_att:
                return value
        return None

    def _facts(self, pk, pk_att, target, visited=None):
        '''
        Returns the list of facts for the rows of table `target` whose column
        `pk_att` equals `pk`, followed by the facts of the rows reachable
        through the connections of `target`.

        No facts are produced for the rows of the target table of the context
        itself; for it only the connections are followed.

        @param visited: set of table names used to stop the recursion from
                        revisiting tables; callers normally omit it.
        '''
        # A fresh set per top-level call: a shared default set would keep
        # the tables visited for one example and silently skip them for
        # every following example.
        if visited is None:
            visited = set()

        facts = []
        if target != self.db.target_table:
            # Work on a copy, removing the class attribute from the list
            # owned by the context would change it for everybody else.
            cols = list(self.db.cols[target])

            # Skip the class attribute
            if self.db.target_att in cols:
                cols.remove(self.db.target_att)
            attributes = self.db.fmt_cols(cols)

            # All rows matching `pk`
            # NOTE(review): `pk` is interpolated into the SQL text unquoted;
            # consider passing it as a query parameter once the paramstyle of
            # the underlying driver is confirmed.
            self.cursor.execute("SELECT %s FROM %s WHERE `%s`=%s" % (attributes, target, pk_att, pk))
            for row in self.cursor:
                row_pk = self._row_pk(target, cols, row)
                row_pk_name = '%s%s' % (target, str(row_pk))

                # Each attr-value becomes one fact
                for idx, col in enumerate(row):
                    attr_name = cols[idx]

                    # We give pks/fks a symbolic name based on the table and id
                    if attr_name in self.db.fkeys[target]:
                        origin_table = self.db.reverse_fkeys[(target, attr_name)]
                        if origin_table != self.db.target_table:
                            # NOTE(review): the symbolic name is computed but
                            # no fact is emitted for foreign keys - confirm
                            # whether one was intended.
                            col = '%s%s' % (origin_table, str(col))
                        else:
                            continue
                    elif attr_name == self.db.pkeys[target]:
                        predicate = 'has_%s' % target
                        facts.append('%s(%s)' % (predicate, row_pk_name))

                        if predicate not in self._predicates:
                            self._predicates.add(predicate)
                            self._template.append('%s(-%s)' % (predicate,
                                                               target))

                    # Constants
                    else:
                        predicate = 'has_%s' % attr_name
                        col = self._discretize_check(target, attr_name, col)
                        facts.append('%s(%s, %s)' % (predicate,
                                                     row_pk_name,
                                                     str(col)))

                        if predicate not in self._predicates:
                            self._predicates.add(predicate)
                            self._template.append('%s(+%s, #%s)' % (predicate,
                                                                    target,
                                                                    attr_name))

        # Recursively follow links to other tables
        for table in self.db.tables:
            if (target, table) not in self.db.connected:
                continue

            for this_att, that_att in self.db.connected[(target, table)]:
                if table not in visited:

                    # Link case 1: pk_att is a fk in another table
                    visited.add(target)
                    if this_att == pk_att:
                        facts.extend(self._facts(pk,
                                                 that_att,
                                                 table,
                                                 visited=visited))

                    # Link case 2: this_att is a fk of another table
                    else:
                        attributes = self.db.fmt_cols([this_att])
                        self.cursor.execute("SELECT %s FROM %s WHERE `%s`=%s" % (attributes, target, pk_att, pk))
                        # Materialize the keys first: the recursive calls
                        # reuse the same cursor.
                        fk_list = [row[0] for row in self.cursor]
                        for fk in fk_list:
                            facts.extend(self._facts(fk,
                                                     that_att,
                                                     table,
                                                     visited=visited))
        return facts

    def _discretize_check(self, table, att, col):
        '''
        Replaces the value with an appropriate interval symbol, if available.

        With thresholds t_0..t_k registered for (table, att) the result is
        the quoted label '=< t_0', '(t_i-1, t_i]' or '> t_k'; without
        thresholds the value is returned unchanged.
        '''
        label = col
        if table in self.discr_intervals and att in self.discr_intervals[table]:
            intervals = self.discr_intervals[table][att]
            last = len(intervals) - 1

            for i, value in enumerate(intervals):
                if col <= value:
                    # The position decides between the first and an inner
                    # interval; testing the previous threshold for
                    # truthiness would mistake a threshold of 0 for "none".
                    if i == 0:
                        label = "'=< %.2f'" % value
                    else:
                        label = "'(%.2f, %.2f]'" % (intervals[i - 1], value)
                    break
                elif i == last and col > value:
                    label = "'> %.2f'" % value
                    break

        return label

    def dataset(self):
        '''
        Returns the db context as a list of interpretations, i.e., a list of
        facts true for each example.

        The result is a single string with one line per row of the target
        table, sorted by class value: the class value followed by the
        comma-separated facts of that example.
        '''
        target = self.db.target_table
        db_examples = self.db.rows(target, [self.db.target_att, self.db.pkeys[target]])

        examples = []
        for cls, pk in sorted(db_examples, key=lambda ex: ex[0]):
            facts = self._facts(pk, self.db.pkeys[target], target)
            examples.append('%s %s' % (cls, ', '.join(facts)))

        return '\n'.join(examples)

    def default_template(self):
        '''
        Returns the template built from the predicates encountered so far,
        as a bracketed, comma-separated list. The template is filled while
        the facts are generated, so it is empty before `dataset` has run.
        '''
        return '[%s]' % (', '.join(self._template))
464

465

466
467
468
if __name__ == '__main__':
    # Manual smoke test: prints the TreeLiker dataset of the 'trains'
    # database. Connects to the database server named below.
    from context import DBConnection, DBContext

    # NOTE(review): the connection credentials are hard-coded in the source;
    # consider reading them from the environment or a local config file.
    context = DBContext(DBConnection('ilp','ilp123','ged.ijs.si','trains'))
    context.target_table = 'trains'
    context.target_att = 'direction'

    # Alternative context:
    # context = DBContext(DBConnection('ilp','ilp123','ged.ijs.si','muta_188'))
    # context.target_table = 'drugs'
    # context.target_att = 'active'

    intervals = {'cars': {'position' : [1, 3]}}

    treeliker = TreeLikerConverter(context, discr_intervals=intervals)

    # Call form of print with a single argument works on Python 2 and 3.
    print(treeliker.dataset())

    # The other converters can be tried the same way:
    # rsd = RSD_Converter(context, discr_intervals=intervals, dump=True)
    # print(rsd.background_knowledge())
    # aleph = Aleph_Converter(context, target_att_val='east', discr_intervals=intervals, dump=True)
    # print(aleph.background_knowledge())
    # orange = Orange_Converter(context)
    # orange.target_Orange_table()