Commit 9d02fac8 authored by balanche's avatar balanche

Modification de l'algorithme K-means : celui ci utilise maintenant un Model...

Modification de l'algorithme K-means : celui-ci utilise maintenant un Model pour définir les distances à utiliser. Still WIP
parent 3749b1ec
......@@ -4,6 +4,10 @@ import jcl.data.attribute.Attribute;
import jcl.data.distance.Distance;
import jcl.data.distance.DistanceParameter;
import jcl.data.distance.MetaDistance;
import jcl.data.distance.MetaDistanceEuclidean;
import jcl.data.distance.NumericalEuclideanDistance;
import jcl.data.distance.sequential.DistanceDTW;
import jcl.data.sampling.Sampler;
/**
* This class represents a classification model.
......@@ -53,4 +57,37 @@ public class Model {
public void setMetaDistance(MetaDistance metaDistance) {
this.metaDistance = metaDistance;
}
/**
 * Generates the default Model for a DataObject: DTW distance for sequence
 * attributes, Euclidean distance for every other (assumed numerical)
 * attribute, combined with the Euclidean meta-distance.
 *
 * @param dataObject the DataObject whose attributes define which distances to use
 * @return the default model
 */
public static Model generateDefaultModel(DataObject dataObject) {
	Distance[] distances = new Distance[dataObject.getNbAttributes()];
	for (int i = 0; i < distances.length; i++) {
		if (dataObject.getAttribute(i).getTypeAttribute() == Attribute.SEQUENCE_ATTRIBUTE) {
			// DTW distance for sequential attributes...
			distances[i] = DistanceDTW.getInstance();
		} else {
			// ...and Euclidean distance for numerical attributes
			distances[i] = NumericalEuclideanDistance.getInstance();
		}
	}
	// defines how the per-attribute scores are combined; Euclidean by default
	MetaDistance metaDistance = MetaDistanceEuclidean.getInstance();
	return new Model(distances, metaDistance);
}
/**
 * Generates a naive Model that uses the Euclidean distance for every
 * attribute regardless of its type, combined with the Euclidean
 * meta-distance.
 *
 * @param dataObject the DataObject whose attribute count sizes the model
 * @return the naive model
 */
public static Model generateNaiveModel(DataObject dataObject) {
	Distance[] distances = new Distance[dataObject.getNbAttributes()];
	// Euclidean distance for every attribute, whatever its type
	for (int i = 0; i < distances.length; i++) {
		distances[i] = NumericalEuclideanDistance.getInstance();
	}
	// defines how the per-attribute scores are combined; Euclidean by default
	MetaDistance metaDistance = MetaDistanceEuclidean.getInstance();
	return new Model(distances, metaDistance);
}
}
......@@ -140,15 +140,6 @@ public abstract class Sampler implements MemoryFlush, Cloneable, Serializable, P
*/
protected abstract void updatePercentage();
/**
* Return the element from the whole source data at a given index
*
* @param index
* the element index
* @return the element
*/
public abstract DataObject getDataObject(int index);
/**
* Return an iterator that retrieve the whole source data representation.
* This method should retrieve the data in the most straight forward manner,
......
......@@ -32,8 +32,8 @@ import jcl.weights.GlobalWeights;
public class Kmeans {
private static final int NB_OBJECTS = 1000;
private static final int SEQUENCE_LENGTH = 100;
private static final int NB_OBJECTS = 20;
private static final int SEQUENCE_LENGTH = 5;
private static final int NB_TURNS = 10;
private static final int NB_CLUSTERS = 2;
......
......@@ -8,6 +8,7 @@ import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.Vector;
import javax.naming.Context;
......@@ -18,7 +19,16 @@ import jcl.Classification;
import jcl.clustering.Cluster;
import jcl.clustering.ClusteringResult;
import jcl.data.Data;
import jcl.data.DataObject;
import jcl.data.Model;
import jcl.data.SimpleData;
import jcl.data.attribute.Attribute;
import jcl.data.distance.Distance;
import jcl.data.distance.DistanceParameter;
import jcl.data.distance.MetaDistance;
import jcl.data.distance.MetaDistanceEuclidean;
import jcl.data.distance.NumericalEuclideanDistance;
import jcl.data.distance.sequential.DistanceEuclidean;
import jcl.jcld.interJcld;
import jcl.learning.LearningMethod;
import jcl.learning.LearningParameters;
......@@ -323,12 +333,12 @@ public class SingleClassification extends Classification {
}
@Override
public void classify() {
public void classify() {
/*
* Add all the processes to monitor in the list (they have to implement progressable)
*/
this.addProgressable(this.learningMethod);
this.addProgressable(null);// had null because this.learningResult is not defined yet
this.addProgressable(null);// add null because this.learningResult is not defined yet
if (this.useRmi() && !this.isMaclaw()) {
if (this.sshParameters == null) {
......
package jcl.learning.methods.monostrategy.kmeans;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
......@@ -280,6 +282,7 @@ public class ClassifierKmeans extends LearningMethod {
@Override
public LearningResult learn(Data data) {
// System.out.println("début learn: " + new SimpleDateFormat("HH:mm:ss.SSS").format(new Date()));
this.progressM = this.getNbIters();
LearningResultKmeans result = new LearningResultKmeans(data, (ParametersKmeans) this.parameters, this.samples);
......@@ -345,7 +348,8 @@ public class ClassifierKmeans extends LearningMethod {
if (log) System.out.println("Alert:" + distanceGlobale + "->" + result.getDistanceGlobale());
}
}
// System.out.println("fin learn: " + new SimpleDateFormat("HH:mm:ss.SSS").format(new Date()));
endProgress();
return result;
}
......
......@@ -3,8 +3,10 @@ package jcl.learning.methods.monostrategy.kmeans;
import java.awt.Color;
import java.io.StringReader;
import java.rmi.server.RMIFailureHandler;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Scanner;
......@@ -18,7 +20,9 @@ import jcl.data.Model;
import jcl.data.attribute.Attribute;
import jcl.data.distance.Distance;
import jcl.data.distance.DistanceParameter;
import jcl.data.distance.EmptyDistanceParameter;
import jcl.data.distance.MetaDistance;
import jcl.data.distance.sequential.ParameterDTW;
import jcl.data.mask.DummyMask;
import jcl.data.mask.Mask;
import jcl.data.mask.MultiIDIntArrayMask;
......@@ -139,6 +143,7 @@ public class LearningResultKmeans extends LearningResult {
@Override
public ClusteringResult classify(Data data, boolean fromSample) {
// System.out.println("début classify: " + new SimpleDateFormat("HH:mm:ss.SSS").format(new Date()));
for (int i = 0 ; i < this.seeds.size() ; i++) {
if (this.seeds.get(i) instanceof LightHardSeed) {
((LightHardSeed) this.seeds.get(i)).setId(i);
......@@ -161,6 +166,7 @@ public class LearningResultKmeans extends LearningResult {
result = ClusteringResult.gerenerateDefaultClusteringResult(this, clusterMap, null, this.weights,
this.seeds.size(), null, data, this.qualityIndices, getColors());
}
// System.out.println("fin classify: " + new SimpleDateFormat("HH:mm:ss.SSS").format(new Date()));
return result;
}
......@@ -260,7 +266,6 @@ public class LearningResultKmeans extends LearningResult {
int clusterMap[] = new int[nbObjects];
int nbThreads = ((ParametersKmeans) this.params).nbThreads;
nbThreads = 1;
if (((ParametersKmeans) this.params).fuzzy) {
// pour chaque objet...
int i =0;
......@@ -313,77 +318,37 @@ public class LearningResultKmeans extends LearningResult {
*
* ((HardSeed) this.seeds.get(clusterMap[i])).addObject(obj); }
*/
ParametersKmeans params = (ParametersKmeans) this.params;
ParametersKmeans params = (ParametersKmeans) this.params;
if (params.parameters == null) {
OldThreadedAffectation[] tabThreads = new OldThreadedAffectation[nbThreads];
int nbObjectsPerThread = (int) Math.ceil((double)nbObjects / nbThreads);
//int threadedClusterMap [][] = new int[nbThreads][nbObjectsPerThread];
for (int th = 0; th < nbThreads - 1; th++) {
if (onSample) {
tabThreads[th] = new OldThreadedAffectation(clusterMap,
data.iterator(th * nbObjectsPerThread, (th + 1) * nbObjectsPerThread - 1),
th * nbObjectsPerThread, (th + 1) * nbObjectsPerThread - 1,
this.weights, this.seeds);
} else {
tabThreads[th] = new OldThreadedAffectation(clusterMap,
data.getWholeSourceDataObjects(th * nbObjectsPerThread, (th + 1) * nbObjectsPerThread - 1),
th * nbObjectsPerThread, (th + 1) * nbObjectsPerThread - 1,
this.weights, this.seeds);
}
tabThreads[th].start();
}
if (onSample) {
tabThreads[nbThreads - 1] = new OldThreadedAffectation(clusterMap,
data.iterator((nbThreads - 1) * nbObjectsPerThread, nbObjects - 1),
(nbThreads - 1) * nbObjectsPerThread, nbObjects - 1,
this.weights, this.seeds);
} else {
tabThreads[nbThreads - 1] = new OldThreadedAffectation(clusterMap,
data.getWholeSourceDataObjects((nbThreads - 1) * nbObjectsPerThread, nbObjects - 1),
(nbThreads - 1) * nbObjectsPerThread, nbObjects - 1,
this.weights, this.seeds);
}
tabThreads[nbThreads - 1].start();
for (int th = 0; th < nbThreads; th++) {
try {
tabThreads[th].join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
for (int th = 0; th < nbThreads; th++) {
distanceGlobale += tabThreads[th].getThreadGlobalDistance();
}
} else {
System.err.println("Parameters should not be null");
} else {
ThreadedAffectation[] tabThreads = new ThreadedAffectation[nbThreads];
int nbObjectsPerThread = nbObjects / nbThreads;
Iterator<DataObject> it;
for (int th = 0; th < nbThreads - 1; th++) {
if (onSample) {
tabThreads[th] = new ThreadedAffectation(clusterMap,
data.iterator(th * nbObjectsPerThread, (th + 1) * nbObjectsPerThread - 1),
data.getModel(), params.parameters[th]);
data.getModel(), params.parameters[th],
th * nbObjectsPerThread);
} else {
tabThreads[th] = new ThreadedAffectation(clusterMap,
data.getWholeSourceDataObjects(th * nbObjectsPerThread, (th + 1) * nbObjectsPerThread - 1),
data.getModel(), params.parameters[th]);
data.getModel(), params.parameters[th],
th * nbObjectsPerThread);
}
tabThreads[th].start();
}
if (onSample) {
tabThreads[nbThreads - 1] = new ThreadedAffectation(clusterMap,
data.iterator((nbThreads - 1) * nbObjectsPerThread, nbObjects - 1),
data.getModel(), params.parameters[nbThreads - 1]);
data.getModel(), params.parameters[nbThreads - 1],
(nbThreads - 1) * nbObjectsPerThread);
} else {
tabThreads[nbThreads - 1] = new ThreadedAffectation(clusterMap,
data.getWholeSourceDataObjects((nbThreads - 1) * nbObjectsPerThread, nbObjects - 1),
data.getModel(), params.parameters[nbThreads - 1]);
data.getModel(), params.parameters[nbThreads - 1],
(nbThreads - 1) * nbObjectsPerThread);
}
tabThreads[nbThreads - 1].start();
......@@ -409,6 +374,7 @@ public class LearningResultKmeans extends LearningResult {
}
}
endProgress();
return clusterMap;
}
......@@ -847,70 +813,6 @@ public class LearningResultKmeans extends LearningResult {
}
}
/**
 * Legacy per-thread assignment worker: walks the slice
 * [startIndex, stopIndex] of the dataset and assigns each object to its
 * nearest seed, using the seed-level distance weighted by the per-cluster
 * weights. Accumulates the sum of the minimal distances for its slice.
 *
 * Weights and seeds are deep-copied in the constructor so that concurrent
 * workers never share mutable state; only the (disjoint) clusterMap slice
 * is written to shared memory.
 *
 * @deprecated superseded by ThreadedAffectation, which delegates distance
 *             computation to the Model instead of the seeds.
 */
@Deprecated
private class OldThreadedAffectation extends Thread {

	int[] clusterMap;
	Iterator<DataObject> dataset;
	int startIndex, stopIndex;
	ClassificationWeights weights;
	Vector<KmeansSeed> seeds;
	double threadGlobalDistance = 0;

	public OldThreadedAffectation(int[] clusterMap, Iterator<DataObject> iterator,
			int startIndex, int stopIndex, ClassificationWeights weights, Vector<KmeansSeed> seeds) {
		this.clusterMap = clusterMap;
		this.dataset = iterator;
		this.startIndex = startIndex;
		this.stopIndex = stopIndex;
		// private copies: this worker must not share mutable state with siblings
		this.weights = (ClassificationWeights) weights.clone();
		this.seeds = new Vector<KmeansSeed>();
		for (KmeansSeed seed : seeds) {
			this.seeds.add((KmeansSeed) seed.clone());
		}
	}

	public void run() {
		for (int i = startIndex; i <= stopIndex; i++) {
			DataObject obj = dataset.next();
			// assume cluster 0 is closest, then challenge it with the others
			int best = 0;
			double distMin = seeds.get(0).distance(obj, weights.getWeights(0));
			for (int j = 1; j < seeds.size(); j++) {
				double distTemp = seeds.get(j).distance(obj, weights.getWeights(j));
				if (distTemp < distMin) {
					distMin = distTemp;
					best = j;
				}
			}
			clusterMap[i] = best;
			threadGlobalDistance += distMin;
			incProgress();
		}
	}

	public double getThreadGlobalDistance() {
		return threadGlobalDistance;
	}
}
private class ThreadedAffectation extends Thread {
......@@ -920,9 +822,10 @@ public class LearningResultKmeans extends LearningResult {
DistanceParameter[] parameters;
DataObject[] seedsCopy;
int threadGlobalDistance = 0;
int startI = 0;
public ThreadedAffectation(int[] clusterMap, Iterator<DataObject> iterator, Model model,
DistanceParameter[] parameters) {
DistanceParameter[] parameters, int startI) {
this.clusterMap = clusterMap;
this.dataset = iterator;
this.model = model;
......@@ -931,11 +834,11 @@ public class LearningResultKmeans extends LearningResult {
for (int s = 0; s < seedsCopy.length; s++) {
seedsCopy[s] = (DataObject) seeds.get(s).getCenter().clone();
}
this.startI = startI;
}
public void run() {
int i = 0;
int i = this.startI;
while(dataset.hasNext()) {
DataObject obj = dataset.next();
MetaDistance metaDistance = model.getMetaDistance();
......@@ -952,6 +855,7 @@ public class LearningResultKmeans extends LearningResult {
}
distanceGlobale += distMin;
//((HardSeed) seeds.get(clusterMap[i])).addObject(obj);
incProgress();
i++;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment