...
 
Commits (3)
This diff is collapsed.
...@@ -23,13 +23,9 @@ import jcl.clustering.constraints.CannotLinkConstraint; ...@@ -23,13 +23,9 @@ import jcl.clustering.constraints.CannotLinkConstraint;
import jcl.clustering.constraints.Constraint; import jcl.clustering.constraints.Constraint;
import jcl.clustering.constraints.MustLinkConstraint; import jcl.clustering.constraints.MustLinkConstraint;
import jcl.data.Data; import jcl.data.Data;
import jcl.data.Model; import jcl.data.distance.DistanceModel;
import jcl.data.attribute.AttributeMultiDimSequence;
import jcl.data.distance.Distance;
import jcl.data.distance.DistanceParameter; import jcl.data.distance.DistanceParameter;
import jcl.data.distance.MetaDistance; import jcl.data.distance.average.AverageParameter;
import jcl.data.distance.MetaDistanceEuclidean;
import jcl.data.distance.sequential.ParameterDTW;
import jcl.data.mask.IntArrayMask; import jcl.data.mask.IntArrayMask;
import jcl.data.mask.Mask; import jcl.data.mask.Mask;
import jcl.io.results.CSVResultWriter; import jcl.io.results.CSVResultWriter;
...@@ -37,7 +33,6 @@ import jcl.learning.methods.monostrategy.kmeans.ParametersKmeans; ...@@ -37,7 +33,6 @@ import jcl.learning.methods.monostrategy.kmeans.ParametersKmeans;
import jcl.learning.methods.multistrategy.samarah.HybridClassification; import jcl.learning.methods.multistrategy.samarah.HybridClassification;
import jcl.learning.methods.multistrategy.samarah.SamarahConfig; import jcl.learning.methods.multistrategy.samarah.SamarahConfig;
import jcl.utils.RandomizeTools; import jcl.utils.RandomizeTools;
import jcl.weights.ClassificationWeights;
import jcl.weights.GlobalWeights; import jcl.weights.GlobalWeights;
import multiCube.tools.image.ImageHelper; import multiCube.tools.image.ImageHelper;
import mustic.gui.ClassificationFrame; import mustic.gui.ClassificationFrame;
...@@ -45,7 +40,7 @@ import mustic.utils.io.CSVUtils; ...@@ -45,7 +40,7 @@ import mustic.utils.io.CSVUtils;
public class TestA2CNESIterative { public class TestA2CNESIterative {
public static void main(String[] args) { public static void main(String[] args) {
HybridClassification classification = new HybridClassification(); HybridClassification classification = new HybridClassification(null, null);
final String datasetName = "FacesUCR"; final String datasetName = "FacesUCR";
final String datasetPath = "FacesUCR"; final String datasetPath = "FacesUCR";
...@@ -53,10 +48,13 @@ public class TestA2CNESIterative { ...@@ -53,10 +48,13 @@ public class TestA2CNESIterative {
final String resultPath = System.getProperty("user.home")+"/A2CNES/results_iter/"; final String resultPath = System.getProperty("user.home")+"/A2CNES/results_iter/";
final int nInf = 12; final int nInf = 12;
final int nSup = 17; final int nSup = 17;
final int ag1_seeds = 16; final int[] ag_seeds = {16, 20, 24};
final int ag2_seeds = 20;
final int ag3_seeds = 24;
final int nb_iter = 15; final int nb_iter = 15;
File directory = new File("log");
if (!directory.exists()){
directory.mkdir();
}
// String testResultPath = System.getProperty("user.home")+"/A2CNES/Train_results/"; // String testResultPath = System.getProperty("user.home")+"/A2CNES/Train_results/";
...@@ -64,9 +62,6 @@ public class TestA2CNESIterative { ...@@ -64,9 +62,6 @@ public class TestA2CNESIterative {
final Data dataTest = TestA2CNES.getDataFromFile(dataPath+datasetPath+"/test/"+datasetName+".data", '\t', "test", null); final Data dataTest = TestA2CNES.getDataFromFile(dataPath+datasetPath+"/test/"+datasetName+".data", '\t', "test", null);
// AttributeMultiDimSequence.setMode(AttributeMultiDimSequence.EUCLIDIEN);
AttributeMultiDimSequence.setMode(AttributeMultiDimSequence.DTW_BARYCENTRE);
final DateTime startTime = DateTime.now(); final DateTime startTime = DateTime.now();
...@@ -94,7 +89,13 @@ public class TestA2CNESIterative { ...@@ -94,7 +89,13 @@ public class TestA2CNESIterative {
classification.setAdvancedParameters(degradation, classRatio, solutionType, kIntern, classification.setAdvancedParameters(degradation, classRatio, solutionType, kIntern,
kExtern, unificationType, criterion, constraintsWgt); kExtern, unificationType, criterion, constraintsWgt);
ClassificationWeights weights = new GlobalWeights(dataTest.getOneDataObject()); final DistanceModel distanceModel = DistanceModel.generateNaiveModel(dataTest.getOneDataObject(),
new GlobalWeights(dataTest.getOneDataObject()));
final DistanceParameter[][] distanceParameters = DistanceModel.generateDefaultDistanceParameters(
3, distanceModel, dataTest);
AverageParameter[] averageParameters = DistanceModel.generateDefaultAverageParameters(
distanceModel, dataTest);
final Vector<Thread> threadList = new Vector<Thread>(); final Vector<Thread> threadList = new Vector<Thread>();
final Vector<Classification> classifList = new Vector<Classification>(); final Vector<Classification> classifList = new Vector<Classification>();
...@@ -206,9 +207,11 @@ public class TestA2CNESIterative { ...@@ -206,9 +207,11 @@ public class TestA2CNESIterative {
// extractAndAddConstraints(subset, constraints, subsetSize, null); // extractAndAddConstraints(subset, constraints, subsetSize, null);
// currentData.updateAndSetConstraintsToSample(subset); // currentData.updateAndSetConstraintsToSample(subset);
classif.addAgent(new ParametersKmeans(ag1_seeds, nb_iter, weights), currentData);
classif.addAgent(new ParametersKmeans(ag2_seeds, nb_iter, weights), currentData); for (int a : ag_seeds) {
classif.addAgent(new ParametersKmeans(ag3_seeds, nb_iter, weights), currentData); classif.addAgent(new ParametersKmeans(a, nb_iter, distanceModel,
distanceParameters, averageParameters), currentData);
}
final String path_to_add = resultPath + datasetName + "/"; final String path_to_add = resultPath + datasetName + "/";
classif.setName(ImageHelper.stripExtension(filename)+"-"+i+ classif.setName(ImageHelper.stripExtension(filename)+"-"+i+
...@@ -239,16 +242,9 @@ public class TestA2CNESIterative { ...@@ -239,16 +242,9 @@ public class TestA2CNESIterative {
BufferedWriter bw = new BufferedWriter(fw); BufferedWriter bw = new BufferedWriter(fw);
PrintWriter out = new PrintWriter(bw); PrintWriter out = new PrintWriter(bw);
Distance[] distances = new Distance[1]; // a distance is set for every attribute
distances[0] = jcl.data.distance.sequential.DistanceDTWMD.getInstance(); // second attribute (sequential) compared with the DTW distance
MetaDistance metaDistance = MetaDistanceEuclidean.getInstance(); // defines the way the two scores are combined (possibility to weight)
Model model = new Model(distances, metaDistance);
int seqLength = ((AttributeMultiDimSequence) dataTest.getOneDataObject().getAttribute(0)).sequence.length;
DistanceParameter[] distanceParameters = new DistanceParameter[1];
distanceParameters[0] = new ParameterDTW(new double[seqLength][seqLength]); //but yes for DTW (requires a matrix to work in)
for (int i = 0 ; i < 5 ; i++) { for (int i = 0 ; i < 5 ; i++) {
int[] clustMap = classif.getClusteringResult().getClusterMap(); int[] clustMap = classif.getClusteringResult().getClusterMap();
int[] satisifiedMap = new int[constraints.size()]; int[] satisifiedMap = new int[constraints.size()];
for (int j = 0 ; j < constraints.size() ; j++) { for (int j = 0 ; j < constraints.size() ; j++) {
...@@ -274,10 +270,10 @@ public class TestA2CNESIterative { ...@@ -274,10 +270,10 @@ public class TestA2CNESIterative {
out.println(c.toString()+";"+ out.println(c.toString()+";"+
Constraint.marginalSilhouetteScore( Constraint.marginalSilhouetteScore(
ml.getFirstIndex(), classif.getClusteringResult(), ml.getFirstIndex(), classif.getClusteringResult(),
model, distanceParameters)+";"+ distanceModel , distanceParameters[0])+";"+
Constraint.marginalSilhouetteScore( Constraint.marginalSilhouetteScore(
ml.getSecondIndex(), classif.getClusteringResult(), ml.getSecondIndex(), classif.getClusteringResult(),
model, distanceParameters) distanceModel , distanceParameters[0])
); );
} else { } else {
...@@ -285,16 +281,31 @@ public class TestA2CNESIterative { ...@@ -285,16 +281,31 @@ public class TestA2CNESIterative {
out.println(c.toString()+";"+ out.println(c.toString()+";"+
Constraint.marginalSilhouetteScore( Constraint.marginalSilhouetteScore(
cl.getFirstIndex(), classif.getClusteringResult(), cl.getFirstIndex(), classif.getClusteringResult(),
model, distanceParameters)+";"+ distanceModel , distanceParameters[0])+";"+
Constraint.marginalSilhouetteScore( Constraint.marginalSilhouetteScore(
cl.getSecondIndex(), classif.getClusteringResult(), cl.getSecondIndex(), classif.getClusteringResult(),
model, distanceParameters) distanceModel , distanceParameters[0])
); );
} }
} }
classif.setAdvancedParameters(degradation, classRatio, solutionType, kIntern, classif.setAdvancedParameters(degradation, classRatio, solutionType, kIntern,
kExtern, unificationType, criterion, 95); kExtern, unificationType, criterion, 95);
classif.newIteration(subset); classif.newIteration(subset);
FileWriter fw2 = null;
try {
fw2 = new FileWriter("log/"+rand+"sat_cst"+classif.getName()+".log", true);
} catch (IOException e) {
e.printStackTrace();
}
BufferedWriter bw2 = new BufferedWriter(fw2);
PrintWriter out2 = new PrintWriter(bw2);
int countSat = 0;
for(Constraint c : subset) {
if (c.evaluate(classif.getClusteringResult()) == 1) {
countSat++;
}
}
out2.write(subset.size()+";"+countSat);
try { try {
new CSVResultWriter(classif, path_to_add + classif.getName()+"_"+(i+1)).write(); new CSVResultWriter(classif, path_to_add + classif.getName()+"_"+(i+1)).write();
......