两种DBSCAN算法的Java实现

首先构建 Point 类。点的数据基本是从别人的博客里直接复制过来的,但我发现他们的代码或多或少都有些问题,于是心血来潮,自己把两种 DBSCAN 算法分别实现了一遍。

package crazyjava.dbscan;

import java.util.Set;

public class Point {
    private double[] value;
    private boolean visited;
    private int cluster;


    public Point() {
    }

    public Point(double[] value, boolean visited, int cluster) {
        this.value = value;
        this.visited = visited;
        this.cluster = cluster;
    }


    public double[] getValue() {
        return value;
    }

    public void setValue(double[] value) {
        this.value = value;
    }

    public boolean isVisited() {
        return visited;
    }


    public void setVisited(boolean visited) {
        this.visited = visited;
    }

    public int getCluster() {
        return cluster;
    }

    public void setCluster(int cluster) {
        this.cluster = cluster;
    }

    @Override
    public String toString() {
        return value[0] + "," + value[1] + "的类别:" + cluster;
    }

    public static void initpointset(Set<Point> pointSet) {
        double[][] points = {
                {3.0, 8.04},
                {4.0, 7.95},
                {4.4, 8.58},
                {3.6, 8.81},
                {5.0, 8.33},
                {6.0, 6.96},
                {17.0, 4.24},
                {18.0, 4.26},
                {16.0, 3.84},
                {17.0, 4.82},
                {15.0, 5.68},
                {17.0, 5.68},
                {11.0, 10.68},
                {13.0, 9.68},
                {11.8, 10.0},
                {12.0, 11.18},
                {8.0, 12.0},
                {9.2, 9.68},
                {8.8, 11.2},
                {10.0, 11.4},
                {7.0, 9.68},
                {6.1, 10.68},
                {5.70, 1.68},
                {5.0, 2.68},
                {12.0, 0.68}
        };
        for (double[] point : points) {
            pointSet.add(new Point(point, false, 0));
        }
    }

    public static double distance(Point p1, Point p2) {
        double[] p1a = p1.getValue();
        double[] p2a = p2.getValue();
        double distance = 0.0;

        for (int i = 0; i < p1a.length; i++) {
            distance += Math.pow(p1a[i] - p2a[i], 2);
        }
        distance = Math.sqrt(distance);
        //System.out.println(p1a[0] + " " + p1a[1] + " " + p2a[0] + " " + p2a[1] + "的distance是" + distance);
        return distance;
    }

}

算法一:具体细节我没有仔细调试,但基本算法已经实现了。

package crazyjava.dbscan;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import static crazyjava.dbscan.Point.distance;
import static crazyjava.dbscan.Point.initpointset;

/**
 * DBSCAN variant that first collects every core point and then grows one cluster from
 * each core point that has not been visited yet.
 *
 * <p>A point counts as a core point when strictly more than {@code MIN_PTS} points of the
 * data set (the point itself included, because its distance to itself is 0) lie strictly
 * closer than {@code RADIUS}.
 * NOTE(review): textbook DBSCAN usually uses "at least minPts" and "within eps"; the
 * strict comparisons are kept here so results stay unchanged - confirm which is intended.
 *
 * @author unclewang
 * @since 2018/4/19 21:41
 */
public class DBScanCode {
    /** All points of the data set; filled by {@link #main(String[])}. */
    private static final Set<Point> pointSet = new HashSet<>();
    /** Neighbour-count threshold used by the core-point test. */
    private static final int MIN_PTS = 5;
    /** Neighbourhood radius used by the core-point test and by cluster expansion. */
    private static final double RADIUS = 2.5;
    /** Core points collected by {@link #findcores()}. */
    private static final Set<Point> cores = new HashSet<>();


    /**
     * Scans the data set and records every core point in the shared core set.
     *
     * @return the shared set of core points (the same instance on every call)
     */
    public static Set<Point> findcores() {
        for (Point candidate : pointSet) {
            int neighbours = 0;
            for (Point other : pointSet) {
                if (distance(other, candidate) < RADIUS) {
                    neighbours++;
                }
            }
            if (neighbours > MIN_PTS) {
                cores.add(candidate);
            }
        }
        return cores;
    }

    /**
     * Assigns cluster ids, starting at 1, by growing a cluster from every core point that
     * is still unvisited when its turn comes.
     *
     * @param cores core points to start clusters from; the same set also decides which
     *              reached points are expanded further
     */
    public static void setCores(Set<Point> cores) {
        int clusterId = 0;
        for (Point core : cores) {
            if (core.isVisited()) {
                // Already absorbed into a cluster grown from an earlier core point.
                continue;
            }
            clusterId++;
            core.setCluster(clusterId);
            expandCluster(core, clusterId, cores);
        }
    }

    /**
     * Grows cluster {@code id} from {@code point} using the core points collected by
     * {@link #findcores()}. The cluster id of {@code point} itself is not changed here.
     *
     * @param point starting point of the expansion; it is marked as visited
     * @param id    cluster id given to every point reached
     */
    public static void denstiyConnected(Point point, int id) {
        expandCluster(point, id, cores);
    }

    /**
     * Marks every unvisited point that is reachable from {@code seed} through a chain of
     * core points as visited and labels it with {@code id}; each expanded point is printed.
     * An explicit stack is used instead of recursion so the depth of the chain is not
     * limited by the thread's call stack.
     */
    private static void expandCluster(Point seed, int id, Set<Point> coreSet) {
        Deque<Point> pending = new ArrayDeque<>();
        seed.setVisited(true);
        pending.push(seed);
        while (!pending.isEmpty()) {
            Point current = pending.pop();
            System.out.println(current.toString());
            for (Point candidate : pointSet) {
                if (candidate.isVisited()) {
                    continue;
                }
                if (distance(candidate, current) < RADIUS) {
                    candidate.setCluster(id);
                    candidate.setVisited(true);
                    // Only core points pull further points in; others end the chain.
                    if (coreSet.contains(candidate)) {
                        pending.push(candidate);
                    }
                }
            }
        }
    }

    /**
     * Loads the sample data, clusters it and prints every point; points whose cluster id
     * is still 0 are printed with the noise prefix.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        initpointset(pointSet);
        Set<Point> cores = findcores();
        System.out.println(cores.size());
        setCores(cores);
        System.out.println("DBScan 打印结果");
        for (Point point : pointSet) {
            if (point.getCluster() == 0) {
                System.out.println("噪声点" + point.toString());
            } else {
                System.out.println(point.toString());
            }
        }
    }

}

算法二

package crazyjava.dbscan;


import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static crazyjava.dbscan.Point.distance;
import static crazyjava.dbscan.Point.initpointset;

/**
 * DBSCAN variant that visits the points one by one and, whenever an unvisited point turns
 * out to be a core point, immediately grows a complete cluster from it.
 *
 * <p>A point counts as a core point when strictly more than {@code minpts} points (the
 * point itself included, because its distance to itself is 0) lie strictly closer than
 * {@code radius}.
 */
public class DBScanCode2 {
    /** All points of the data set; neighbourhood queries always search this set. */
    private static Set<Point> pointSet = new HashSet<>();
    /** The same points as a list, used as the visiting order. */
    private static ArrayList<Point> pointlist = new ArrayList<>();
    /** Neighbour-count threshold used by the core-point test. */
    private static int minpts = 5;
    /** Neighbourhood radius. */
    private static double radius = 2.5;

    /**
     * Clusters the given points, numbering clusters from 1. Points that end up in no
     * cluster keep the cluster id they had before the call.
     *
     * <p>Neighbours are looked up in the static point set, so {@code pointlist} is
     * expected to hold those same points.
     *
     * @param pointlist points in the order in which they are tried as cluster seeds
     */
    public static void nostopping(ArrayList<Point> pointlist) {
        int clusterid = 1;
        for (Point point : pointlist) {
            if (point.isVisited()) {
                continue;
            }
            point.setVisited(true);
            List<Point> seeds = findset(point);
            if (seeds.size() <= minpts) {
                // Not a core point: it may still join a cluster grown from another seed.
                continue;
            }
            point.setCluster(clusterid);
            expandCluster(seeds, clusterid);
            clusterid++;
        }
    }

    /**
     * Grows one cluster from the neighbourhood of a core point. {@code seeds} is extended
     * in place with the neighbourhood of every further core point found, so the cluster
     * keeps growing until no new point can be reached. Points that already carry a
     * non-zero cluster id keep it.
     */
    private static void expandCluster(List<Point> seeds, int clusterid) {
        // Tracks what is already queued so each point is appended at most once.
        Set<Point> queued = new HashSet<>(seeds);
        for (int i = 0; i < seeds.size(); i++) {
            Point current = seeds.get(i);
            if (!current.isVisited()) {
                current.setVisited(true);
                List<Point> reachable = findset(current);
                if (reachable.size() > minpts) {
                    for (Point next : reachable) {
                        if (queued.add(next)) {
                            seeds.add(next);
                        }
                    }
                }
            }
            if (current.getCluster() == 0) {
                current.setCluster(clusterid);
            }
        }
    }

    /**
     * Collects every point of the data set that lies strictly closer than {@code radius}
     * to {@code point}; the point itself is part of the result when it is in the data set.
     */
    private static ArrayList<Point> findset(Point point) {
        ArrayList<Point> neighbours = new ArrayList<>();
        for (Point other : pointSet) {
            if (distance(other, point) < radius) {
                neighbours.add(other);
            }
        }
        return neighbours;
    }

    /**
     * Loads the sample data, clusters it and prints every point; points whose cluster id
     * is still 0 are printed with the noise prefix.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        initpointset(pointSet);
        pointlist.addAll(pointSet);
        nostopping(pointlist);
        System.out.println("DBScan2 打印结果");
        for (Point point : pointSet) {
            if (point.getCluster() == 0) {
                System.out.println("噪声点" + point.toString());
            } else {
                System.out.println(point.toString());
            }
        }
    }
}

照着伪代码来写的感觉真的挺好,希望日后自己也能写出这样简洁的伪代码。

发表评论

电子邮件地址不会被公开。