Commit 222268d6 authored by Anze Vavpetic's avatar Anze Vavpetic
Browse files

added support for stratified x-validation, if the input is as an orange table or dbcontext

parent b7317ca3
......@@ -231,18 +231,19 @@ class Workflow(models.Model):
current_iteration = 0
# create folds
rand.seed(input_seed)
rand.shuffle(input_list)
folds = []
if hasattr(input_list, "get_items_ref"):
import orange
# Orange table on input, so we cannot do slices
indices = range(len(input_list))
folds_indices = [indices[i::input_fold] for i in range(input_fold)]
folds = []
for fold_indices in folds_indices:
folds.append(input_list.get_items_ref(fold_indices))
indices = orange.MakeRandomIndicesCV(input_list, randseed=input_seed)
for i in range(input_fold):
output_train = input_list.select(indices, i, negate=1)
output_test = input_list.select(indices, i)
folds.append((output_train, output_test))
else:
rand.seed(input_seed)
rand.shuffle(input_list)
folds = [input_list[i::input_fold] for i in range(input_fold)]
# pass forward the seed
......@@ -258,8 +259,13 @@ class Workflow(models.Model):
i.save()
for i in range(len(folds)):
output_train = folds[:i] + folds[i+1:]
output_test = folds[i]
#import pdb; pdb.set_trace()
if hasattr(input_list, "get_items_ref"):
output_test = folds[i][0]
output_train = folds[i][1]
else:
output_train = folds[:i] + folds[i+1:]
output_test = folds[i]
if input_type == 'DBContext':
output_train_obj = context.copy()
output_train_obj.orng_tables[context.target_table] = output_train
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment