Commit ced17d79 authored by Anze Vavpetic's avatar Anze Vavpetic
Browse files

generating an interpretation for an example: each attribute becomes one fact;...

generating an interpretation for an example: each attribute becomes one fact; also added support for discretization intervals
parent 5c931ece
......@@ -81,19 +81,19 @@ class ILP_Converter(Converter):
punct = '.' if i == n_intervals-1 else ';'
if i == 0:
# Condition: att =< value_i
label = '=< %d' % value
condition = '%s =< %d' % (att.capitalize(), value)
label = '=< %.2%f' % value
condition = '%s =< %.2%f' % (att.capitalize(), value)
discretize_goals.append('\t((%s = \'%s\', %s)%s' % (var_att, label, condition, punct))
if i < n_intervals-1:
# Condition: att in (value_i, value_i+1]
value_next = intervals[i+1]
label = '(%d, %d]' % (value, value_next)
condition = '%s > %d, %s =< %d' % (att.capitalize(), value, att.capitalize(), value_next)
label = '(%.2%f, %.2%f]' % (value, value_next)
condition = '%s > %.2%f, %s =< %.2%f' % (att.capitalize(), value, att.capitalize(), value_next)
discretize_goals.append('\t(%s = \'%s\', %s)%s' % (var_att, label, condition, punct))
else:
# Condition: att > value_i
label = '> %d' % value
condition = '%s > %d' % (att.capitalize(), value)
label = '> %.2%f' % value
condition = '%s > %.2%f' % (att.capitalize(), value)
discretize_goals.append('\t(%s = \'%s\', %s))%s' % (var_att, label, condition, punct))
return ['%s_%s(%s, %s) :-' % (table, att, var_table, var_att),
values_goal] + discretize_goals
......@@ -315,63 +315,148 @@ class TreeLikerConverter(Converter):
self.discr_intervals = kwargs.pop('discr_intervals', {}) if kwargs else {}
Converter.__init__(self, *args, **kwargs)
def __facts(self, pk, pk_att, target, visited=set()):
def _row_pk(self, target, cols, row):
row_pk = None
for idx, col in enumerate(row):
if cols[idx] == self.db.pkeys[target]:
row_pk = col
break
return row_pk
def _facts(self, pk, pk_att, target, visited=set()):
'''
Returns the facts for the given entity with pk in `table`.
'''
facts = []
if target != self.db.target_table:
cols = self.db.cols[target]
if self.db.target_att in cols: # Skip the class attribute
# Skip the class attribute
if self.db.target_att in cols:
cols.remove(self.db.target_att)
attributes = self.db.fmt_cols(cols)
# All rows matching `pk`
self.cursor.execute("SELECT %s FROM %s WHERE `%s`=%s" % (attributes, target, pk_att, pk))
for row in self.cursor:
values = []
row_pk = self._row_pk(target, cols, row)
row_pk_name = '%s%s' % (target, str(row_pk))
# Each attr-value becomes one fact
for idx, col in enumerate(row):
attr_name = cols[idx]
# We give pks/fks a symbolic name based on the table and id
if attr_name in self.db.fkeys[target]:
col = '%s%s' % (self.db.reverse_fkeys[(target, attr_name)], str(col))
origin_table = self.db.reverse_fkeys[(target, attr_name)]
if origin_table != self.db.target_table:
col = '%s%s' % (origin_table, str(col))
else:
continue
elif attr_name == self.db.pkeys[target]:
col = '%s%s' % (target, str(col))
values.append(str(col))
facts.append('%s(%s)' % (target, ', '.join(values)))
facts.append('has_%s(%s)' % (target, row_pk_name))
# Constants
else:
col = self._discretize_check(target, attr_name, col)
facts.append('has_%s(%s, %s)' % (attr_name,
row_pk_name,
str(col)))
# Recursively follow links to other tables
for table in self.db.tables:
if (target, table) not in self.db.connected:
continue
for this_att, that_att in self.db.connected[(target, table)]:
if table not in visited:
# pk_att is a fk in another table
# Link case 1: pk_att is a fk in another table
visited.add(target)
if this_att == pk_att:
facts.extend(self.__facts(pk,
facts.extend(self._facts(pk,
that_att,
table,
visited=visited))
# this_att is a fk of another entity
# Link case 2: this_att is a fk of another table
else:
attributes = self.db.fmt_cols([this_att])
self.cursor.execute("SELECT %s FROM %s WHERE `%s`=%s" % (attributes, target, pk_att, pk))
fk_list = [row[0] for row in self.cursor]
for fk in fk_list:
facts.extend(self.__facts(fk,
facts.extend(self._facts(fk,
that_att,
table,
visited=visited))
return facts
def _discretize_check(self, table, att, col):
'''
Replaces the value with an appropriate interval symbol, if available.
'''
label = col
if table in self.discr_intervals and att in self.discr_intervals[table]:
intervals = self.discr_intervals[table][att]
n_intervals = len(intervals)
prev_value = None
for i, value in enumerate(intervals):
if i > 0:
prev_value = intervals[i-1]
if not prev_value and col <= value:
label = "'=< %.2f'" % value
break
elif prev_value and col <= value:
label = "'(%.2f, %.2f]'" % (prev_value, value)
break
elif col > value and i == n_intervals - 1:
label = "'> %.2f'" % value
break
return label
n_intervals = len(intervals)
for i, value in enumerate(intervals):
punct = '.' if i == n_intervals-1 else ';'
if i == 0:
# Condition: att =< value_i
label = '=< %.2%f' % value
condition = '%s =< %d' % (att.capitalize(), value)
discretize_goals.append('\t((%s = \'%s\', %s)%s' % (var_att, label, condition, punct))
if i < n_intervals-1:
# Condition: att in (value_i, value_i+1]
value_next = intervals[i+1]
label = '(%d, %d]' % (value, value_next)
condition = '%s > %d, %s =< %d' % (att.capitalize(), value, att.capitalize(), value_next)
discretize_goals.append('\t(%s = \'%s\', %s)%s' % (var_att, label, condition, punct))
else:
# Condition: att > value_i
label = '> %d' % value
condition = '%s > %d' % (att.capitalize(), value)
discretize_goals.append('\t(%s = \'%s\', %s))%s' % (var_att, label, condition, punct))
def dataset(self):
'''
Returns the db context as a list of interpretations, i.e., a list of
facts true for each example.
'''
target = self.db.target_table
db_examples = self.db.rows(target, [self.db.target_att, self.db.pkeys[target]])
examples = []
for cls, pk in sorted(db_examples, key=lambda ex: ex[0]):
facts = self.__facts(pk, self.db.pkeys[target], target)
facts = self._facts(pk, self.db.pkeys[target], target)
examples.append('%s %s' % (cls, ', '.join(facts)))
return '\n'.join(examples)
def default_template(self):
pass
if __name__ == '__main__':
from context import DBConnection, DBContext
......@@ -382,12 +467,12 @@ if __name__ == '__main__':
# context = DBContext(DBConnection('ilp','ilp123','ged.ijs.si','muta_188'))
# context.target_table = 'drugs'
# context.target_att = 'active'
#intervals = {'cars': {'position' : [1, 3]}}
intervals = {'cars': {'position' : [1, 3]}}
#import cPickle
#cPickle.dump(intervals, open('intervals.pkl','w'))
#rsd = RSD_Converter(context, discr_intervals=intervals, dump=True)
#aleph = Aleph_Converter(context, target_att_val='east', discr_intervals=intervals, dump=True)
treeliker = TreeLikerConverter(context, discr_intervals=[])
treeliker = TreeLikerConverter(context, discr_intervals=intervals)
print treeliker.dataset()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment