分享

转:开源一个基于Storm 分布式BP神经网络的Demo(Java版)

 hehffyy 2017-06-23
RT,这个Demo是我去年写,功能上没有大问题
该demo有3个缺点
1、代码性能一般,有一定的优化空间。由于是Demo,当时写的时候只考虑功能,没考虑性能
2、神经细胞层数被写死了,只有2层,如果要添加细胞层需要改源码。
3、训练任然是串行的,没有充分利用分布式框架并行计算优势。但是后续的计算任然是并行的,只要添加DRPC Client或者自定义DSpout就行了

帖子里我就不介绍Storm和神经网络原理了
我会单独写一篇博文介绍,预计本周内完成,连接:http://blog.csdn.net/tntzbzc/article/details/19974515

Spout启动+Drpc Client 训练++Drpc Client 计算
三个我放在一个main中了
第一个是服务端job 启动
第二个是训练(串行),测试的训练数据是任意整数,把它转成32个由0或1组成的double数组作为输入参数(input),以及一个double[4]的结果值(real)
real[4]:1 0 0 0 正奇数 0 1 0 0 正偶数 0 0 1 0 负奇数 0 0 0 1 负偶数
两层细胞的权重weight是随机生成的
随机生成了1000个样本,训练2000次

第三个是计算,如果想实现分布式并行计算,可以自己添加Client

public class DrpcClient {

	/****************************************
	 * by CSDN 撸大湿 Email : tntzhou@hotmail.com
	 ****************************************/

	public static void main(String[] args) throws Exception {
		int InputHideCount = 32; /* Input Hide输入数量 */
		int HideOutCount = 10; /* Hide Bolt输出数量,等于是Out Bolt的输入数量 */
		int OutCount = 4; /* Out输出数量,等于Real的数量 */

		TopologyBuilder builder = new TopologyBuilder();
		DRPCSpout drpcSpout = new DRPCSpout("BPTrain");
		builder.setSpout("drpcSpout", drpcSpout, 1);
		builder.setBolt("hide", new HideBolt(), HideOutCount).allGrouping("drpcSpout");

		// OutBolt的传参 必须等于 HideBolt的个数
		builder.setBolt("out", new OutBolt(HideOutCount), OutCount).allGrouping("hide");

		// TrainBPFinsh的传参 必须等于 OutBolt的个数
		builder.setBolt("finsh", new TrainBPFinsh(OutCount), 1).allGrouping("out");

		builder.setBolt("return", new ReturnResults(), 1).allGrouping("finsh");
		Config conf = new Config();
		conf.setNumWorkers(Integer.parseInt(args[1]));
		StormSubmitter.submitTopology(args[0], conf, builder.createTopology());

		String hideweight = getWeightStr(InputHideCount + 1, HideOutCount); // 隐藏层的权重
		String outweight = getWeightStr(HideOutCount + 1, OutCount);// 输出层的权重
		DRPCClient client = new DRPCClient("mynode001", 3772);
		int[] ranInt = new int[1000];
		for (int i = 0; i < ranInt.length; i++) {
			ranInt[i] = new java.util.Random().nextInt();
		}
		int Num = 0;
		System.out.println("开始训练");
		for (int i = 0; i < 2000; i++) {
			double r = 0d;
			for (int j = 0; j < ranInt.length; j++) {//
				RandomNum MyRandom = new RandomNum(ranInt[j]);
				String input = MyRandom.getInputDataStr();
				String real = MyRandom.getRealDataStr();

				String[] result = client.execute(
						"BPTrain",
						String.valueOf(Num) + "::" + input + "::" + real + "::" + hideweight + "::"
								+ outweight).split("::");
				// 参数传入全部靠一个字符串,收取也是字符串,最基本的DRPC Client调用
				hideweight = result[0];
				outweight = result[1];
				r = Double.parseDouble(result[2]);
				if (j % 100 == 0)
					System.out.println(r); // 输出收敛度
			}
			if (r < 0.005) {
				System.out.println("训练结束");
				break;
			}
		}

		while (true) {
			byte[] input = new byte[10];
			System.in.read(input);
			int value = 0;
			try {
				value = Integer.parseInt(new String(input).trim());
			} catch (Exception e) {
				break;
			}
			RandomNum rawVal = new RandomNum(value);

			String[] resultstr = client.execute(
					"BPTrain",
					String.valueOf(++Num) + "::" + rawVal.getInputDataStr() + "::"
							+ rawVal.getRealDataStr() + "::" + hideweight + "::" + outweight).split("::");

			double max = -Integer.MIN_VALUE;
			int idx = -1;
			String[] result = resultstr[3].split(",");
			for (int i = 0; i != result.length; i++) {
				if (Double.valueOf(result[i]) > max) {
					max = Double.valueOf(result[i]);
					idx = i;
				}
			}

			switch (idx) {
			case 0:
				System.out.format("%d是一个正奇数\n", value);
				break;
			case 1:
				System.out.format("%d是一个正偶数\n", value);
				break;
			case 2:
				System.out.format("%d是一个负奇数\n", value);
				break;
			case 3:
				System.out.format("%d是一个负偶数\n", value);
				break;
			}
		}
	}

	static String getWeightStr(int inCount, int outCount) {
		StringBuilder Wgt = new StringBuilder();
		for (int i = 0; i < outCount; i++) {
			for (int j = 0; j < inCount; j++) {
				Random random = new Random();
				double v = random.nextDouble();
				double rand = random.nextDouble() > 0.5 ? v : -v;
				Wgt.append(rand / 2);
				if (j != inCount - 1)
					Wgt.append(",");
			}
			if (i != outCount - 1)
				Wgt.append(":");
		}
		// System.out.println(Wgt.toString().split(":")[0].split(",").length);
		return Wgt.toString();
	}
}

隐藏层神经细胞 Hide Bolt
public class HideBolt implements IRichBolt {
	/****************************************
	 * by CSDN 撸大湿 Email : tntzhou@hotmail.com
	 ****************************************/
	private static final long serialVersionUID = -3242401692275116210L;
	OutputCollector collector;
	int TaskID = 0;

	@SuppressWarnings("rawtypes")
	@Override
	public void prepare(Map stormConf, TopologyContext context, OutputCollector _collector) {
		collector = _collector;
		TaskID = context.getThisTaskIndex(); // 每个Bolt代表一个神经细胞
															// Bolt TaskID == 细胞 id
	}

	@Override
	public void execute(Tuple tuple) {

		String[] t = tuple.getString(0).split("::");

		String jobID = tuple.getString(1); // jobid,drpc rid
		int TNum = Integer.valueOf(t[0]); // tuple id
		String[] inputstr = t[1].split(",");
		double[] input = new double[inputstr.length]; // 传入值
		String[] hideweightstr = t[3].split(":");
		double[][] hideweight = new double[hideweightstr.length][input.length + 1]; // 隐藏层的权重
		for (int i = 0; i < inputstr.length; i++) {
			input[i] = Double.parseDouble(inputstr[i]);
		}
		for (int i = 0; i < hideweightstr.length; i++) {
			String[] tmpw = hideweightstr[i].split(",");
			for (int j = 0; j < tmpw.length; j++) {
				hideweight[i][j] = Double.parseDouble(tmpw[j]);
			}
		}
		double inputsum = 0;

		/****************************************
		 * 计算输出
		 ****************************************/
		for (int i = 0; i < input.length; i++) {

			inputsum += hideweight[TaskID][i] * input[i];
		}

		inputsum += hideweight[TaskID][input.length];
		double HideOut = 1.0 / (1.0 + Math.exp(-inputsum));

		String[] realstr = t[2].split(",");
		double[] real = new double[realstr.length];

		for (int i = 0; i < realstr.length; i++) {
			real[i] = Double.parseDouble(realstr[i]);
		}

		String[] outweightstr = t[4].split(":");
		double[][] outweight = new double[outweightstr.length][outweightstr[0].split(",").length];

		for (int i = 0; i < outweightstr.length; i++) {
			String[] tmpw = outweightstr[i].split(",");
			for (int j = 0; j < tmpw.length; j++) {
				outweight[i][j] = Double.parseDouble(tmpw[j]);
			}
		}

		collector.emit(new Values(TNum, TaskID, input, HideOut, real, hideweight, outweight, jobID));
		collector.ack(tuple);
	}

	@Override
	public void declareOutputFields(OutputFieldsDeclarer declarer) {
		declarer.declare(new Fields("tnum", "HideTaskID", "Input", "HideOut", "Real", "HideWeight",
				"OutWeight", "jobID"));
	}

	@Override
	public void cleanup() {
	}

	@Override
	public Map<String, Object> getComponentConfiguration() {
		return null;
	}

}


//输出层神经细胞
public class OutBolt implements IRichBolt {

	/****************************************
	 * by CSDN 撸大湿 Email : tntzhou@hotmail.com
	 ****************************************/
	private static final long serialVersionUID = -7483206983562705977L;
	OutputCollector collector;
	int TaskID = 0;
	HashMap<Integer, double[]> HideOutMap = new HashMap<Integer, double[]>();
	HashMap<Integer, ArrayList<Tuple>> MyTuple = new HashMap<Integer, ArrayList<Tuple>>();
	int HideTaskCount = 0;

	public OutBolt(int _hidetaskcount) {
		this.HideTaskCount = _hidetaskcount;
	}

	@SuppressWarnings("rawtypes")
	@Override
	public void prepare(Map stormConf, TopologyContext context, OutputCollector _collector) {
		collector = _collector;
		TaskID = context.getThisTaskIndex();
	}

	@Override
	public void execute(Tuple tuple) {

		int TNum = tuple.getInteger(0);
		int HideTaskID = tuple.getInteger(1);
		double hideout = tuple.getDouble(3);
		if (!HideOutMap.containsKey(TNum)) {
			HideOutMap.put(TNum, new double[HideTaskCount]);
			MyTuple.put(TNum, new ArrayList<Tuple>());
		}
		HideOutMap.get(TNum)[HideTaskID] = hideout;
		MyTuple.get(TNum).add(tuple);
		double[] input = null;
		double[] real = null;
		double[][] hideweight = null;
		double[][] outweight = null;
		String jobID = null;

		outweight = (double[][]) tuple.getValue(6);
		jobID = tuple.getString(7);

		if (MyTuple.get(TNum).size() == HideTaskCount) {
			input = (double[]) tuple.getValue(2);
			real = (double[]) tuple.getValue(4);
			hideweight = (double[][]) tuple.getValue(5);

			/****************************************
			 * 计算输出
			 ****************************************/
			Double sum = 0d;
			for (int i = 0; i < HideOutMap.get(TNum).length; i++) {
				sum += HideOutMap.get(TNum)[i] * outweight[TaskID][i];
			}
			sum += outweight[TaskID][HideOutMap.get(TNum).length];
			double Out = 1.0 / (1.0 + Math.exp(-sum));
			collector.emit(new Values(TNum, TaskID, input, HideOutMap.get(TNum), Out, real,
					hideweight, outweight, jobID));
			HideOutMap.remove(TNum);
			MyTuple.remove(TNum);
		}
		collector.ack(tuple);
	}

	@Override
	public void declareOutputFields(OutputFieldsDeclarer declarer) {
		declarer.declare(new Fields("tnum", "OutTaskID", "Input", "HideOut", "Out", "Real",
				"HideWeight", "OutWeight", "jobID"));
	}

	@Override
	public void cleanup() {
		// TODO Auto-generated method stub

	}

	@Override
	public Map<String, Object> getComponentConfiguration() {
		// TODO Auto-generated method stub
		return null;
	}

}
BP反馈层,训练迭代终点,
public class TrainBPFinsh implements IRichBolt, FinishedCallback {

	/****************************************
	 * by CSDN 撸大湿 Email : tntzhou@hotmail.com
	 ****************************************/
	private static final long serialVersionUID = 5303881246503874591L;
	OutputCollector collector;
	int TaskID = 0;
	int OutTaskCount = 0;
	HashMap<Integer, double[]> OutMap = new HashMap<Integer, double[]>();
	HashMap<Integer, ArrayList<Tuple>> MyTuple = new HashMap<Integer, ArrayList<Tuple>>();

	public TrainBPFinsh(int _OutTaskCount) {
		OutTaskCount = _OutTaskCount;
	}

	@SuppressWarnings("rawtypes")
	@Override
	public void prepare(Map stormConf, TopologyContext context, OutputCollector _collector) {
		TaskID = context.getThisTaskIndex();
		collector = _collector;

	}

	@Override
	public void execute(Tuple tuple) {

		int TNum = tuple.getInteger(0);
		int OutTaskID = tuple.getInteger(1);
		double out = tuple.getDouble(4);

		if (!OutMap.containsKey(TNum)) {
			OutMap.put(TNum, new double[OutTaskCount]);
			MyTuple.put(TNum, new ArrayList<Tuple>());
		}
		OutMap.get(TNum)[OutTaskID] = out;
		MyTuple.get(TNum).add(tuple);
		double[] input = null;
		double[] hideout = null;
		double[] real = null;
		double[][] hideweight = null;
		double[][] outweight = null;
		String jobID = null;

		jobID = tuple.getString(8);
		if (MyTuple.get(TNum).size() == OutTaskCount) {
			input = (double[]) tuple.getValue(2);
			hideout = (double[]) tuple.getValue(3);
			real = (double[]) tuple.getValue(5);
			hideweight = (double[][]) tuple.getValue(6);
			outweight = (double[][]) tuple.getValue(7);
			jobID = tuple.getString(8);

			double[] outDelta = new double[OutTaskCount];
			double[] hideDelta = new double[hideout.length];
			double[] outout = OutMap.get(TNum);
			double rate = 0.5; // 收敛速率
			// 计算每个输出单元的误差项
			for (int i = 0; i < hideout.length; i++) {
				System.out.println("hideout " + i + " : " + hideout[i]);
			}
			for (int i = 0; i < outout.length; i++) {
				System.out.println("outout " + i + " : " + outout[i]);
			}
			double outErrSum = 0;
			for (int i = 0; i < OutTaskCount; ++i) {
				outDelta[i] = outout[i] * (1.0 - outout[i]) * (real[i] - outout[i]);
				System.out.println("real " + i + " : " + real[i]);
				System.out.println("outDelta " + i + " : " + outDelta[i]);
				outErrSum += Math.pow((real[i] - outout[i]), 2);
			}

			// 计算每个隐藏单元的误差项
			for (int i = 0; i < hideout.length; ++i) {
				double sum = 0;
				for (int j = 0; j < OutTaskCount; ++j) {
					sum += outweight[j][i] * outDelta[j];
				}
				hideDelta[i] = hideout[i] * (1d - hideout[i]) * sum;
				System.out.println("hideDelta " + i + " : " + hideDelta[i]);
			}

			StringBuilder hideweightstr = new StringBuilder();
			StringBuilder outweightstr = new StringBuilder();

			// 收敛输出层的权值
			for (int i = 0; i < OutTaskCount; ++i) {
				for (int j = 0; j < hideout.length; ++j) {
					outweight[i][j] += rate * outDelta[i] * hideout[j];
					outweightstr.append(outweight[i][j]);
					outweightstr.append(",");
				}

				outweight[i][hideout.length] += rate * outDelta[i];
				outweightstr.append(outweight[i][hideout.length]);
				if (i != OutTaskCount - 1) {
					outweightstr.append(":");
				}
			}

			// 收敛隐含层的权值
			for (int i = 0; i < hideout.length; ++i) {
				for (int j = 0; j < input.length; ++j) {
					hideweight[i][j] += rate * hideDelta[i] * input[j];
					hideweightstr.append(hideweight[i][j]);
					hideweightstr.append(",");
				}

				hideweight[i][input.length] += rate * hideDelta[i];
				hideweightstr.append(hideweight[i][input.length]);
				if (i != hideout.length - 1) {
					hideweightstr.append(":");
				}
			}
			String outStr = "";
			for (int i = 0; i < outout.length; i++) {
				outStr += String.valueOf(outout[i]);
				if (i != outout.length - 1)
					outStr += ",";
			}
			String result = hideweightstr + "::" + outweightstr + "::" + outErrSum + "::" + outStr;
			collector.emit(new Values(result, jobID));
			OutMap.remove(TNum);
			MyTuple.remove(TNum);
		}
		collector.ack(tuple);
	}

	@Override
	public void declareOutputFields(OutputFieldsDeclarer declarer) {
		declarer.declare(new Fields("result", "returnInfo"));

	}

	@Override
	public void cleanup() {
		// TODO Auto-generated method stub

	}

	@Override
	public Map<String, Object> getComponentConfiguration() {
		// TODO Auto-generated method stub
		return null;
	}

	@Override
	public void finishedId(Object id) {

	}

}
把代码打包成bpdemo.jar上传storm server,运行命令
storm jar /home/hadoop/bpdemo.jar master.demo.storm.DrpcClient TrainBP 1 -c nimbus.host=mynode001
开始训练


开始计算

storm集群UI

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多