Commit 3c2b428f authored by Marko Lalovic's avatar Marko Lalovic
Browse files

crossvalidation bugfix

parent 54650a0e
......@@ -236,10 +236,11 @@ class Workflow(models.Model):
if hasattr(input_list, "get_items_ref"):
import orange
# Orange table on input, so we cannot do slices
indices = orange.MakeRandomIndicesCV(input_list, randseed=input_seed)
indices = orange.MakeRandomIndicesCV(input_list, randseed=input_seed, folds=input_fold)
for i in range(input_fold):
output_train = input_list.select(indices, i, negate=1)
output_test = input_list.select(indices, i)
#print len(output_train), len(output_test)
folds.append((output_train, output_test))
else:
rand.seed(input_seed)
......@@ -261,8 +262,9 @@ class Workflow(models.Model):
for i in range(len(folds)):
#import pdb; pdb.set_trace()
if hasattr(input_list, "get_items_ref"):
output_test = folds[i][0]
output_train = folds[i][1]
output_test = folds[i][1]
output_train = folds[i][0]
print len(output_train), len(output_test)
else:
output_train = folds[:i] + folds[i+1:]
output_test = folds[i]
......@@ -279,9 +281,11 @@ class Workflow(models.Model):
fo.unfinish() # resets widgets, (read all widgets.finished=false)
proper_output = fi.outputs.all()[0] # inner output
proper_output.value = output_train
print len(output_train.orng_tables[context.target_table])
proper_output.save()
proper_output = fi.outputs.all()[1] # inner output
proper_output.value = output_test
print len(output_test.orng_tables[context.target_table])
proper_output.save()
fi.finished=True # set the input widget as finished
fi.save()
......
......@@ -149,8 +149,6 @@ class DBContext:
if self.orng_tables:
data = []
for ex in self.orng_tables[table]:
print cols
print self.orng_tables[table].domain
data.append([ex[str(col)] for col in cols])
return data
else:
......@@ -199,7 +197,7 @@ class DBContext:
con.close()
def copy(self):
return copy.copy(self)
return copy.deepcopy(self)
def __repr__(self):
return pprint.pformat({
......
......@@ -75,7 +75,6 @@ class ILP_Converter(Converter):
n_intervals = len(intervals)
for i, value in enumerate(intervals):
punct = '.' if i == n_intervals-1 else ';'
print value
if i == 0:
# Condition: att =< value_i
label = '=< %.2f' % value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment