diff --git a/context-aware/library.py b/context-aware/library.py index d639771aa9960d10a82f019bb1df205b37a03d1d..3362ecdbaabe072601d30236044a68a8a409d5da 100644 --- a/context-aware/library.py +++ b/context-aware/library.py @@ -1,3 +1,5 @@ +from math import floor + def ca_set_binary_threshold_from_skew(input_dict): cost_false_pos = input_dict['cost_false_pos'] cost_false_neg = input_dict['cost_false_neg'] @@ -46,6 +48,7 @@ def ca_rate_driven_threshold_selection(input_dict): from collections import Counter performance = input_dict['score'] + rate = input_dict['rate'] list_score = [] labels = '' n = len(performance['actual']) @@ -53,9 +56,18 @@ def ca_rate_driven_threshold_selection(input_dict): list_score.append((performance['actual'][i],performance['predicted'][i])) output_dict = {} sorted_score = sorted(list_score, key=lambda scr: scr[1],reverse=True) - counter_neg = len([score for score in list_score if score[0] == 0]) - counter_pos = len([score for score in list_score if score[0] == 1]) - output_dict['bin_thres'] = find_best_roc_weight('rate',sorted_score,counter_pos,counter_neg) + + rank = floor(n * (float(rate) / float(100))) + current_rank = 0 + previous = float('inf') + current = previous + for i in range(n): + current = list_score[i][1] + current_rank = current_rank + 1 + if current_rank > rank: + output_dict['bin_thres'] = (previous + current) / float(2) + break + previous = list_score[i][1] return output_dict def ca_score_driven_threshold_selection(input_dict): @@ -86,8 +98,10 @@ def find_best_roc_weight(method,a_list,a_num_positives,a_num_negatives): current = the_roc[1] if current != previous: possible_best_value = get_value(method,xpos,xneg,a_num_positives,a_num_negatives) + print '%f > %f' %(possible_best_value,the_best_value) if possible_best_value > the_best_value: the_best_value = possible_best_value + print '%f -> %f' %(best,(previous + current) / float(2)) best = (previous + current) / float(2) if the_roc[0] == 1: xpos += 1 @@ -103,16 +117,18 @@ def find_best_roc_weight(method,a_list,a_num_positives,a_num_negatives): def get_value(method, TP, TN, P, N): if method == 'accuracy': - accuracy = ( TP + TN ) / float( N + P ) + accuracy = (TP + TN) / float(P+N) return accuracy elif method == 'balanced': balanced = ( TP / float(P) + TN / float(N)) / 2 return balanced FN = P - TP FP = N - TN - recall = TN / float(N) - if TN + FN > 0: - precision = TN / float(TN + FN) + recall = TP / float(P) + if method == 'recall': + return recall + if TP + FP > 0: + precision = TP / float(TP + P) if method == 'precision': return precision if precision + recall > 0: