Commit 719887a7 authored by Anze Vavpetic's avatar Anze Vavpetic
Browse files

adapted CV and discretization to work with DBContext; DBContext now has an...

adapted CV and discretization to work with DBContext; DBContext now has an internal representation with orange tables
parent 8106c2e2
......@@ -322,11 +322,18 @@ def cforange_discretize(input_dict):
import orange
from collections import defaultdict
input_tables = input_dict['dataset']
input_obj = input_dict['dataset']
output_tables=[]
input_type_is_list=type(input_tables) is list
if not input_type_is_list:
input_tables=[input_tables]
input_type = input_obj.__class__.__name__
if input_type == 'DBContext':
context = input_obj
input_tables = [context.orng_tables[tname] for tname in context.tables]
elif input_type != 'list':
input_tables = [input_obj]
else:
input_tables = input_obj
discretizerIndex = int(input_dict['discretizer_id'])
discretizers = [
......@@ -335,16 +342,13 @@ def cforange_discretize(input_dict):
("Entropy-based discretization", orange.EntropyDiscretization), #no arguments
("Bi-modal discretization", orange.BiModalDiscretization),#no arguments
("Fixed discretization", orange.EquiNDiscretization)#FixedDiscretization) #points
]
]
options={}
points=defaultdict(dict)
options = {}
points = defaultdict(dict)
if discretizerIndex in [4]:
#find all cut-off points
points = [float(a) for a in input_dict['points'].replace(" ","").split(",")]
#for k,v in input_dict.items():
# if k.startswith('points'):
# points.append(float(v))
options['points']=sorted(points)
elif discretizerIndex in [0,1]:
options['numberOfIntervals']=int(input_dict['numberOfIntervals'])
......@@ -355,35 +359,29 @@ def cforange_discretize(input_dict):
newattrs = []
for attr in inputdata.domain.attributes:
if attr.varType == orange.VarTypes.Continuous:
newattr=d(attr,inputdata) if discretizerIndex in [0,2,3] else d.constructVariable(attr)
newattr = d(attr,inputdata) if discretizerIndex in [0,2,3] else d.constructVariable(attr)
newattr.name=attr.name
#newattr.name=attr.name[2:] if newattr.name.startswith("D_"):
newattr.name = attr.name
newattrs.append(newattr)
points[inputdata.name][attr.name]=newattr.get_value_from.transformer.points
points[inputdata.name][attr.name] = newattr.get_value_from.transformer.points
else:
newattrs.append(attr)
name=inputdata.name
#for attr in newattrs: #TODO
# if attr.name.startswith("D_"):
# attr.name=attr.name[2:]
#new_t=inputdata.select(newattrs + [inputdata.domain.classVar])
name = inputdata.name
newdomain = orange.Domain(newattrs, inputdata.domain.classVar)
newdomain.addmetas(inputdata.domain.getmetas())
new_t = orange.ExampleTable(newdomain, inputdata)
new_t.name=name
new_t.name = name
output_tables.append(new_t)
#for attr in newattrs:
# print "%s: %s" % (attr.name, attr.values)
#interval4
#newclass = orange.EnumVariable("is versicolor", values = ["no", "yes"])
#newclass.getValueFrom = lambda ex, w: ex["iris"]=="Iris-versicolor"
#newdomain = orange.Domain(data.domain.attributes, newclass)
#data_v = orange.ExampleTable(newdomain, data)
if input_type == 'DBContext':
output = input_obj.copy()
output.orng_tables = dict(zip(input_obj.tables, output_tables))
elif input_type != 'list':
output = output_tables[0]
else:
output = output_tables
output_dict = {'odt': output_tables if input_type_is_list else output_tables[0],'discr_intervals':points} #returns list if input is list
output_dict = {'odt': output, 'discr_intervals': points}
return output_dict
def cforange_attribute_distance(input_dict):
......
......@@ -217,14 +217,33 @@ class Workflow(models.Model):
else:
input_seed = rand.randint(1, 100000000)
# Special case when reading from a DB
input_type = input_list.__class__.__name__
context = None
if input_type == 'DBContext':
context = input_list
input_list = context.orng_tables.get(context.target_table, None)
if not input_list:
raise Exception('CrossValidation: Empty input list!')
progress_total = len(input_list) # for progress bar
current_iteration = 0
#print(input_list, input_fold, input_seed);
# create folds
rand.seed(input_seed);
rand.seed(input_seed)
rand.shuffle(input_list)
folds = [input_list[i::input_fold] for i in range(input_fold)];
folds = []
if hasattr(input_list, "get_items_ref"):
# Orange table on input, so we cannot do slices
indices = range(len(input_list))
folds_indices = [indices[i::input_fold] for i in range(input_fold)]
folds = []
for fold_indices in folds_indices:
folds.append(input_list.get_items_ref(fold_indices))
else:
folds = [input_list[i::input_fold] for i in range(input_fold)]
# pass forward the seed
proper_output = fi.outputs.all()[2] # inner output
......@@ -235,21 +254,28 @@ class Workflow(models.Model):
for i in fo.inputs.all():
if not i.parameter:
if i.connections.count() > 0:
i.value = [];
i.save();
i.value = []
i.save()
for i in range(len(folds)):
#print(folds[i])
output_train = folds[:i] + folds[i+1:]
output_test = folds[i]
if input_type == 'DBContext':
output_train_obj = context.copy()
output_train_obj.orng_tables[context.target_table] = output_train
output_test_obj = context.copy()
output_test_obj.orng_tables[context.target_table] = output_test
output_train = output_train_obj
output_test = output_test_obj
""" Different parameters on which the widgets are going to be run"""
fi.unfinish() # resets widgets, (read all widgets.finished=false)
fo.unfinish() # resets widgets, (read all widgets.finished=false)
proper_output = fi.outputs.all()[0] # inner output
proper_output.value = folds[:i] + folds[i+1:];
#print(folds[:i] + folds[i+1:]);
proper_output.value = output_train
proper_output.save()
proper_output = fi.outputs.all()[1] # inner output
proper_output.value = folds[i]
#print(folds[i]);
proper_output.value = output_test
proper_output.save()
fi.finished=True # set the input widget as finished
fi.save()
......@@ -268,8 +294,6 @@ class Workflow(models.Model):
except:
raise
current_iteration = current_iteration+1
#input_list.value = []
def run(self):
if not USE_CONCURRENCY or not self.widget:
......@@ -555,7 +579,7 @@ class Widget(models.Model):
if self.type == 'regular' or self.type == 'subprocess':
""" if this is a subprocess or a regular widget than true."""
if not self.abstract_widget is None:
"""if this is an abstract widget than true; we save the widget function in a variable """
"""if this is an abstract widget than true we save the widget function in a variable """
function_to_call = getattr(workflows.library,self.abstract_widget.action)
input_dict = {}
outputs = {}
......
from collections import defaultdict
import pprint
import copy
from django import forms
import mysql.connector as sql
import converters
class DBConnection:
'''
......@@ -24,8 +28,9 @@ class DBConnection:
def connect(self):
return sql.connect(user=self.user, password=self.password, host=self.host, database=self.database)
class DBContext:
def __init__(self, connection, find_connections=False):
def __init__(self, connection, find_connections=False, in_memory=True):
'''
Initializes the fields:
tables: list of selected tables
......@@ -78,7 +83,6 @@ class DBContext:
self.fkeys[table].add(col)
self.reverse_fkeys[(table, col)] = ref_table
cursor.execute(
"SELECT table_name, column_name \
FROM information_schema.KEY_COLUMN_USAGE \
......@@ -89,6 +93,18 @@ class DBContext:
self.target_att = None
con.close()
self.orng_tables = None
if in_memory:
self.orng_tables = self.read_into_orange()
def read_into_orange(self):
conv = converters.Orange_Converter(self)
tables = {
self.target_table: conv.target_Orange_table()
}
tables.update(zip(self.tables[1:], conv.other_Orange_tables()))
return tables
def update(self, postdata):
'''
Updates the default selections with user's selections.
......@@ -116,13 +132,28 @@ class DBContext:
def fmt_cols(self, cols):
return ','.join(["`%s`" % col for col in cols])
def rows(self, table, cols):
def fetch(self, table, cols):
'''
Fetches rows from the db.
'''
con = self.connection.connect()
cursor = con.cursor()
cursor.execute("SELECT %s FROM %s" % (self.fmt_cols(cols), table))
con.close()
return [cols for cols in cursor]
def rows(self, table, cols):
'''
Fetches rows from the local cache or from the db if there's no cache.
'''
if self.orng_tables:
data = []
for ex in self.orng_tables[table]:
data.append([ex[col] for col in cols])
return data
else:
return self.fetch(table, cols)
def fetch_types(self, table, cols):
'''
Returns a dictionary of field types for the given table and columns.
......@@ -148,6 +179,9 @@ class DBContext:
self.col_vals[table][col] = [val for (_,val) in cursor]
con.close()
def copy(self):
return copy.copy(self)
def __repr__(self):
return pprint.pformat({
'target_table' : self.target_table,
......@@ -156,6 +190,7 @@ class DBContext:
'cols' : self.cols,
'connected' : self.connected,
'pkeys' : self.pkeys,
'fkeys' : self.fkeys
'fkeys' : self.fkeys,
'orng_tables': [(name, len(table)) for name, table in self.orng_tables.items()] if self.orng_tables else 'not in memory'
})
......@@ -36,13 +36,6 @@ class ILP_Converter(Converter):
def mode(self, predicate, args, recall=1, head=False):
return ':- mode%s(%s, %s(%s)).' % ('h' if head else 'b', str(recall), predicate, ','.join([t+arg for t,arg in args]))
def db_connection(self):
con = self.db.connection
host, db, user, pwd = con.host, con.database, con.user, con.password
return [':- use_module(library(myddas)).', \
':- db_open(mysql, \'%s\'/\'%s\', \'%s\', \'%s\').' % (host, db, user, pwd)] + \
[':- db_import(%s, %s).' % (table, table) for table in self.db.tables]
def connecting_clause(self, table, ref_table):
var_table, var_ref_table = table.capitalize(), ref_table.capitalize()
result=[]
......@@ -144,11 +137,8 @@ class RSD_Converter(ILP_Converter):
modeslist.append(self.mode('%s_%s' % (table, att), [('+', table), ('-', att)]))
modeslist.append(self.mode('instantiate', [('+', att)]))
getters.extend(self.attribute_clause(table, att))
if not self.dump:
b = '\n'.join(self.db_connection() + modeslist + getters + self.user_settings())
else:
b = '\n'.join(modeslist + getters + self.user_settings() + self.dump_tables())
return b
return '\n'.join(modeslist + getters + self.user_settings() + self.dump_tables())
class Aleph_Converter(ILP_Converter):
'''
......@@ -207,12 +197,7 @@ class Aleph_Converter(ILP_Converter):
determinations.append(':- determination(%s/1, %s_%s/2).' % (self.__target_predicate(), table, att))
types.extend(self.constant_type_def(table, att))
getters.extend(self.attribute_clause(table, att))
local_copies = [self.local_copy(table) for table in self.db.tables]
if not self.dump:
b = '\n'.join(self.db_connection() + local_copies + self.user_settings() + modeslist + determinations + types + getters)
else:
b = '\n'.join(self.user_settings() + modeslist + determinations + types + getters + self.dump_tables())
return b
return '\n'.join(self.user_settings() + modeslist + determinations + types + getters + self.dump_tables())
def concept_type_def(self, table):
var_pk = self.db.pkeys[table].capitalize()
......@@ -226,17 +211,6 @@ class Aleph_Converter(ILP_Converter):
return ['%s(%s) :-' % (att.lower(), var_att),
'\t%s(%s).' % (table, variables)]
def db_connection(self):
con = self.db.connection
host, db, user, pwd = con.host, con.database, con.user, con.password
return [':- use_module(library(myddas)).', \
':- db_open(mysql, \'%s\'/\'%s\', \'%s\', \'%s\').' % (host, db, user, pwd)] + \
[':- db_import(%s, tmp_%s).' % (table, table) for table in self.db.tables]
def local_copy(self, table):
cols = ','.join([col.capitalize() for col in self.db.cols[table]])
return ':- repeat, tmp_%s(%s), (%s(%s), !, fail ; assertz(%s(%s)), fail).' % (table, cols, table, cols, table, cols)
class Orange_Converter(Converter):
'''
......@@ -493,24 +467,10 @@ class TreeLikerConverter(Converter):
if __name__ == '__main__':
from context import DBConnection, DBContext
# context = DBContext(DBConnection('ilp','ilp123','ged.ijs.si','trains'))
# context.target_table = 'trains'
# context.target_att = 'direction'
context = DBContext(DBConnection('ilp','ilp123','ged.ijs.si','muta_42'))
context.target_table = 'drugs'
context.target_att = 'active'
# intervals = {'cars': {'position' : [1, 3]}}
#import cPickle
#cPickle.dump(intervals, open('intervals.pkl','w'))
#rsd = RSD_Converter(context, discr_intervals=intervals, dump=True)
#aleph = Aleph_Converter(context, target_att_val='east', discr_intervals=intervals, dump=True)
treeliker = TreeLikerConverter(context)
print treeliker.default_template()
print treeliker.dataset()
#print rsd.background_knowledge()
#print aleph.background_knowledge()
#orange = Orange_Converter(context)
#orange.target_table()
......@@ -1552,6 +1552,28 @@ $(function(){
});
$("#widgets a.crossvalidation").click(function() {
$.post(url['add-cv'], {'active_workflow' : activeCanvasId, 'scrollTop': activeCanvas.scrollTop(), 'scrollLeft':activeCanvas.scrollLeft()}, function(data) {
try {
jsonData = $.parseJSON(data)
if (jsonData.success==false) {
reportError(jsonData.message)
}
}
catch (err)
{
activeCanvas.append(data);
var outer_widget_id = $(data).find(".outer-widget-link").attr('rel');
var outer_widget_workflow_id = $(data).find(".outer-widget-workflow").attr('rel');
$("#widget"+outer_widget_id).remove();
refreshWidget(outer_widget_id,outer_widget_workflow_id);
updateWidgetListeners();
resizeWidgets();
}
},'html');
});
$("#widgets a.input").click(function() {
$.post(url['add-input'], {'active_workflow' : activeCanvasId, 'scrollTop': activeCanvas.scrollTop(), 'scrollLeft':activeCanvas.scrollLeft()}, function(data) {
try {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment