RT,这个Demo是我去年写的,功能上没有大问题
该demo有3个缺点
1、代码性能一般,有一定的优化空间。由于是Demo,当时写的时候只考虑功能,没考虑性能
2、神经细胞层数被写死了,只有2层,如果要添加细胞层需要改源码。
3、训练仍然是串行的,没有充分利用分布式框架并行计算优势。但是后续的计算仍然是并行的,只要添加DRPC Client或者自定义DSpout就行了
帖子里我就不介绍Storm和神经网络原理了
我会单独写一篇博文介绍,预计本周内完成,链接:http://blog.csdn.net/tntzbzc/article/details/19974515
Spout启动+Drpc Client 训练+Drpc Client 计算 三个我放在一个main中了
第一个是服务端job 启动
第二个是训练(串行),测试的训练数据是任意整数,把它转成32个由0或1组成的double数组作为输入参数(input),以及一个double[4]的结果值(real)
real[4]:
1 0 0 0 正奇数
0 1 0 0 正偶数
0 0 1 0 负奇数
0 0 0 1 负偶数
两层细胞的权重weight是随机生成的
随机生成了1000个样本,训练2000次
第三个是计算,如果想实现分布式并行计算,可以自己添加Client
public class DrpcClient {
    /****************************************
     * by CSDN 撸大湿 Email : tntzhou@hotmail.com
     ****************************************/

    /** Default DRPC server host, used when no third argument is given. */
    private static final String DEFAULT_DRPC_HOST = "mynode001";
    /** Default DRPC server port, used when no fourth argument is given. */
    private static final int DEFAULT_DRPC_PORT = 3772;

    /**
     * Submits the "BPTrain" topology, trains the two-layer network through serial DRPC
     * calls, then classifies integers typed on standard input (one per line) until a
     * non-integer line or end of input is read.
     *
     * @param args {@code args[0]} topology name, {@code args[1]} number of workers;
     *             optionally {@code args[2]} DRPC host and {@code args[3]} DRPC port
     * @throws IllegalArgumentException if fewer than two arguments are supplied
     * @throws Exception if topology submission, a DRPC call or reading input fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            throw new IllegalArgumentException(
                    "usage: DrpcClient <topology-name> <num-workers> [drpc-host] [drpc-port]");
        }
        String drpcHost = args.length > 2 ? args[2] : DEFAULT_DRPC_HOST;
        int drpcPort = args.length > 3 ? Integer.parseInt(args[3]) : DEFAULT_DRPC_PORT;

        int inputHideCount = 32; /* number of inputs fed to each hidden neuron */
        int hideOutCount = 10; /* number of hidden neurons == inputs of each output neuron */
        int outCount = 4; /* number of output neurons == length of the "real" vector */

        TopologyBuilder builder = new TopologyBuilder();
        DRPCSpout drpcSpout = new DRPCSpout("BPTrain");
        builder.setSpout("drpcSpout", drpcSpout, 1);
        builder.setBolt("hide", new HideBolt(), hideOutCount).allGrouping("drpcSpout");
        // The OutBolt argument must equal the number of HideBolt tasks.
        builder.setBolt("out", new OutBolt(hideOutCount), outCount).allGrouping("hide");
        // The TrainBPFinsh argument must equal the number of OutBolt tasks.
        builder.setBolt("finsh", new TrainBPFinsh(outCount), 1).allGrouping("out");
        builder.setBolt("return", new ReturnResults(), 1).allGrouping("finsh");

        Config conf = new Config();
        conf.setNumWorkers(Integer.parseInt(args[1]));
        StormSubmitter.submitTopology(args[0], conf, builder.createTopology());

        // Each neuron gets one extra weight in addition to one weight per input.
        String hideweight = getWeightStr(inputHideCount + 1, hideOutCount); // hidden-layer weights
        String outweight = getWeightStr(hideOutCount + 1, outCount); // output-layer weights
        DRPCClient client = new DRPCClient(drpcHost, drpcPort);

        // 1000 random training samples, generated once and reused for every epoch.
        java.util.Random sampleRandom = new java.util.Random();
        int[] ranInt = new int[1000];
        for (int i = 0; i < ranInt.length; i++) {
            ranInt[i] = sampleRandom.nextInt();
        }

        int num = 0;
        System.out.println("开始训练");
        for (int epoch = 0; epoch < 2000; epoch++) {
            double r = 0d;
            for (int j = 0; j < ranInt.length; j++) {
                RandomNum sample = new RandomNum(ranInt[j]);
                String input = sample.getInputDataStr();
                String real = sample.getRealDataStr();
                // All arguments travel as one string and the reply is a string too:
                // the most basic kind of DRPC client call.
                String request = String.valueOf(num) + "::" + input + "::" + real + "::"
                        + hideweight + "::" + outweight;
                String[] result = client.execute("BPTrain", request).split("::");
                hideweight = result[0];
                outweight = result[1];
                r = Double.parseDouble(result[2]);
                if (j % 100 == 0) {
                    System.out.println(r); // print the current error
                }
            }
            // r holds the error of the last sample of this epoch.
            if (r < 0.005) {
                System.out.println("训练结束");
                break;
            }
        }

        // Line-based reading: a fixed 10-byte buffer would truncate 11-character inputs
        // such as "-2147483648" and leak the remainder into the next read.
        java.io.BufferedReader stdin =
                new java.io.BufferedReader(new java.io.InputStreamReader(System.in));
        while (true) {
            String line = stdin.readLine();
            if (line == null) {
                break; // end of input
            }
            int value;
            try {
                value = Integer.parseInt(line.trim());
            } catch (NumberFormatException notAnInteger) {
                break; // any non-integer line ends the session
            }
            RandomNum rawVal = new RandomNum(value);
            String[] resultstr = client.execute(
                    "BPTrain",
                    String.valueOf(++num) + "::" + rawVal.getInputDataStr() + "::"
                            + rawVal.getRealDataStr() + "::" + hideweight + "::" + outweight)
                    .split("::");

            // Pick the output neuron with the largest activation.
            double max = Double.NEGATIVE_INFINITY;
            int idx = -1;
            String[] outputs = resultstr[3].split(",");
            for (int i = 0; i < outputs.length; i++) {
                double activation = Double.parseDouble(outputs[i]);
                if (activation > max) {
                    max = activation;
                    idx = i;
                }
            }
            switch (idx) {
            case 0:
                System.out.format("%d是一个正奇数\n", value);
                break;
            case 1:
                System.out.format("%d是一个正偶数\n", value);
                break;
            case 2:
                System.out.format("%d是一个负奇数\n", value);
                break;
            case 3:
                System.out.format("%d是一个负偶数\n", value);
                break;
            }
        }
    }

    /**
     * Builds a random weight matrix encoded as text: {@code outCount} rows joined by
     * {@code ":"}, each row holding {@code inCount} values joined by {@code ","}.
     * Every value is drawn from the open interval (-0.5, 0.5).
     *
     * @param inCount  number of weights per row
     * @param outCount number of rows
     * @return the encoded matrix; empty when {@code outCount} is not positive
     */
    static String getWeightStr(int inCount, int outCount) {
        Random random = new Random();
        StringBuilder wgt = new StringBuilder();
        for (int i = 0; i < outCount; i++) {
            if (i > 0) {
                wgt.append(":");
            }
            for (int j = 0; j < inCount; j++) {
                if (j > 0) {
                    wgt.append(",");
                }
                double v = random.nextDouble();
                double signed = random.nextDouble() > 0.5 ? v : -v;
                wgt.append(signed / 2);
            }
        }
        return wgt.toString();
    }
}
隐藏层神经细胞 Hide Bolt
/**
 * Hidden-layer neuron. Each task parses the whole DRPC request, computes the sigmoid
 * activation of the neuron whose index equals its task index, and forwards the parsed
 * data together with that activation.
 */
public class HideBolt implements IRichBolt {
    /****************************************
     * by CSDN 撸大湿 Email : tntzhou@hotmail.com
     ****************************************/
    private static final long serialVersionUID = -3242401692275116210L;
    OutputCollector collector;
    int TaskID = 0;

    /**
     * Stores the collector and this task's index.
     *
     * @param stormConf  topology configuration (unused)
     * @param context    supplies the task index
     * @param _collector collector used to emit and ack
     */
    @SuppressWarnings("rawtypes")
    @Override
    public void prepare(Map stormConf, TopologyContext context, OutputCollector _collector) {
        collector = _collector;
        // Each bolt task stands for one neuron: task index == neuron id.
        TaskID = context.getThisTaskIndex();
    }

    /**
     * Parses a request of the form {@code tnum::input::real::hideweight::outweight},
     * computes this neuron's activation and emits it with the parsed data.
     *
     * @param tuple field 0 is the request string, field 1 the DRPC return info
     */
    @Override
    public void execute(Tuple tuple) {
        String[] t = tuple.getString(0).split("::");
        String jobID = tuple.getString(1); // DRPC return info, passed through unchanged
        int TNum = Integer.parseInt(t[0]); // sample id chosen by the client

        double[] input = parseVector(t[1]);
        // Rows are sized input.length + 1: one weight per input plus one trailing weight.
        double[][] hideweight = parseMatrix(t[3].split(":"), input.length + 1);

        // Weighted sum of the inputs, then the trailing weight, then the sigmoid.
        double inputsum = 0;
        for (int i = 0; i < input.length; i++) {
            inputsum += hideweight[TaskID][i] * input[i];
        }
        inputsum += hideweight[TaskID][input.length];
        double HideOut = 1.0 / (1.0 + Math.exp(-inputsum));

        double[] real = parseVector(t[2]);
        String[] outRows = t[4].split(":");
        double[][] outweight = parseMatrix(outRows, outRows[0].split(",").length);

        collector.emit(new Values(TNum, TaskID, input, HideOut, real, hideweight, outweight,
                jobID));
        collector.ack(tuple);
    }

    /** Parses a comma-separated list of doubles. */
    private static double[] parseVector(String text) {
        String[] cells = text.split(",");
        double[] values = new double[cells.length];
        for (int i = 0; i < cells.length; i++) {
            values[i] = Double.parseDouble(cells[i]);
        }
        return values;
    }

    /** Parses comma-separated rows into a matrix with the given number of columns. */
    private static double[][] parseMatrix(String[] rows, int columns) {
        double[][] matrix = new double[rows.length][columns];
        for (int i = 0; i < rows.length; i++) {
            String[] cells = rows[i].split(",");
            for (int j = 0; j < cells.length; j++) {
                matrix[i][j] = Double.parseDouble(cells[j]);
            }
        }
        return matrix;
    }

    @Override
    public void declareOutputFields(OutputFieldsDeclarer declarer) {
        declarer.declare(new Fields("tnum", "HideTaskID", "Input", "HideOut", "Real",
                "HideWeight", "OutWeight", "jobID"));
    }

    @Override
    public void cleanup() {
    }

    @Override
    public Map<String, Object> getComponentConfiguration() {
        return null;
    }
}
// Output-layer neuron
/**
 * Output-layer bolt. Each task buffers the hidden-layer outputs of a sample, keyed by the
 * sample number, and once {@code HideTaskCount} tuples have arrived it emits the sigmoid
 * activation computed from weight row {@code TaskID}.
 */
public class OutBolt implements IRichBolt {
/****************************************
* by CSDN 撸大湿 Email : tntzhou@hotmail.com
****************************************/
private static final long serialVersionUID = -7483206983562705977L;
// Collector handed over in prepare(); used to emit and ack.
OutputCollector collector;
// Task index obtained in prepare(); selects this task's row of the output weights.
int TaskID = 0;
// Sample number -> hidden-layer outputs received so far, indexed by hidden task id.
HashMap<Integer, double[]> HideOutMap = new HashMap<Integer, double[]>();
// Sample number -> tuples received so far; its size is compared with HideTaskCount.
HashMap<Integer, ArrayList<Tuple>> MyTuple = new HashMap<Integer, ArrayList<Tuple>>();
// Number of hidden-layer tuples expected per sample; set by the constructor.
int HideTaskCount = 0;
/**
 * Creates an output-layer bolt.
 *
 * @param _hidetaskcount number of hidden-layer tuples to wait for per sample
 */
public OutBolt(int _hidetaskcount) {
    HideTaskCount = _hidetaskcount;
}
/**
 * Remembers this task's index and the collector used for emitting and acking.
 *
 * @param stormConf  topology configuration (unused)
 * @param context    supplies the task index
 * @param _collector collector to store
 */
@SuppressWarnings("rawtypes")
@Override
public void prepare(Map stormConf, TopologyContext context, OutputCollector _collector) {
    this.TaskID = context.getThisTaskIndex();
    this.collector = _collector;
}
/**
 * Records one hidden-layer output for the tuple's sample number. When
 * {@code HideTaskCount} tuples have been received for that sample, computes this task's
 * activation (weighted sum of the hidden outputs plus the trailing weight of row
 * {@code TaskID}, passed through the sigmoid), emits it with the pass-through data and
 * drops the buffered state. Every input tuple is acked on receipt.
 *
 * @param tuple fields by position: 0 sample number, 1 hidden task id, 2 input,
 *              3 hidden output, 4 real, 5 hidden weights, 6 output weights, 7 job id
 */
@Override
public void execute(Tuple tuple) {
    int TNum = tuple.getInteger(0);
    int HideTaskID = tuple.getInteger(1);
    double hideout = tuple.getDouble(3);

    // Look the per-sample state up once instead of on every access.
    double[] hideOuts = HideOutMap.get(TNum);
    ArrayList<Tuple> received = MyTuple.get(TNum);
    if (hideOuts == null) {
        hideOuts = new double[HideTaskCount];
        received = new ArrayList<Tuple>();
        HideOutMap.put(TNum, hideOuts);
        MyTuple.put(TNum, received);
    }
    hideOuts[HideTaskID] = hideout;
    received.add(tuple);

    if (received.size() == HideTaskCount) {
        // The pass-through fields are taken from the tuple that completes the sample.
        double[] input = (double[]) tuple.getValue(2);
        double[] real = (double[]) tuple.getValue(4);
        double[][] hideweight = (double[][]) tuple.getValue(5);
        double[][] outweight = (double[][]) tuple.getValue(6);
        String jobID = tuple.getString(7);

        // Primitive accumulator: a boxed Double would be re-boxed on every addition.
        double[] weights = outweight[TaskID];
        double sum = 0d;
        for (int i = 0; i < hideOuts.length; i++) {
            sum += hideOuts[i] * weights[i];
        }
        sum += weights[hideOuts.length]; // trailing weight, added without an input factor
        double Out = 1.0 / (1.0 + Math.exp(-sum));

        collector.emit(new Values(TNum, TaskID, input, hideOuts, Out, real, hideweight,
                outweight, jobID));
        HideOutMap.remove(TNum);
        MyTuple.remove(TNum);
    }
    collector.ack(tuple);
}
/**
 * Declares the nine fields of the tuples this bolt emits.
 *
 * @param declarer receives the field declaration
 */
@Override
public void declareOutputFields(OutputFieldsDeclarer declarer) {
    Fields outputFields = new Fields("tnum", "OutTaskID", "Input", "HideOut", "Out", "Real",
            "HideWeight", "OutWeight", "jobID");
    declarer.declare(outputFields);
}
/** No-op: this bolt performs no work on cleanup. */
@Override
public void cleanup() {
}
/**
 * Supplies no component-specific configuration.
 *
 * @return always {@code null}
 */
@Override
public Map<String, Object> getComponentConfiguration() {
    return null;
}
}
BP反馈层,训练迭代终点,
/**
 * Back-propagation step and end point of one training iteration. Collects the
 * activations of all output neurons for a sample, adjusts both weight matrices and emits
 * the encoded result {@code hideweight::outweight::error::outputs} with the job id.
 */
public class TrainBPFinsh implements IRichBolt, FinishedCallback {
    /****************************************
     * by CSDN 撸大湿 Email : tntzhou@hotmail.com
     ****************************************/
    private static final long serialVersionUID = 5303881246503874591L;
    OutputCollector collector;
    int TaskID = 0;
    // Number of output-layer tuples expected per sample; set by the constructor.
    int OutTaskCount = 0;
    // Sample number -> output activations received so far, indexed by output task id.
    HashMap<Integer, double[]> OutMap = new HashMap<Integer, double[]>();
    // Sample number -> tuples received so far; its size is compared with OutTaskCount.
    HashMap<Integer, ArrayList<Tuple>> MyTuple = new HashMap<Integer, ArrayList<Tuple>>();

    /**
     * @param _OutTaskCount number of output-layer tuples to wait for per sample
     */
    public TrainBPFinsh(int _OutTaskCount) {
        OutTaskCount = _OutTaskCount;
    }

    @SuppressWarnings("rawtypes")
    @Override
    public void prepare(Map stormConf, TopologyContext context, OutputCollector _collector) {
        TaskID = context.getThisTaskIndex();
        collector = _collector;
    }

    /**
     * Records one output activation for the tuple's sample number. When
     * {@code OutTaskCount} tuples have arrived, computes the error terms, updates the
     * weights in place, emits the encoded result and drops the buffered state. Every
     * input tuple is acked on receipt.
     *
     * @param tuple fields by position: 0 sample number, 1 output task id, 2 input,
     *              3 hidden outputs, 4 output, 5 real, 6 hidden weights,
     *              7 output weights, 8 job id
     */
    @Override
    public void execute(Tuple tuple) {
        int TNum = tuple.getInteger(0);
        int OutTaskID = tuple.getInteger(1);
        double out = tuple.getDouble(4);

        double[] outout = OutMap.get(TNum);
        ArrayList<Tuple> received = MyTuple.get(TNum);
        if (outout == null) {
            outout = new double[OutTaskCount];
            received = new ArrayList<Tuple>();
            OutMap.put(TNum, outout);
            MyTuple.put(TNum, received);
        }
        outout[OutTaskID] = out;
        received.add(tuple);

        if (received.size() == OutTaskCount) {
            // The pass-through fields are taken from the tuple that completes the sample.
            double[] input = (double[]) tuple.getValue(2);
            double[] hideout = (double[]) tuple.getValue(3);
            double[] real = (double[]) tuple.getValue(5);
            double[][] hideweight = (double[][]) tuple.getValue(6);
            double[][] outweight = (double[][]) tuple.getValue(7);
            String jobID = tuple.getString(8);

            double[] outDelta = new double[OutTaskCount];
            double[] hideDelta = new double[hideout.length];
            double rate = 0.5; // learning rate

            for (int i = 0; i < hideout.length; i++) {
                System.out.println("hideout " + i + " : " + hideout[i]);
            }
            for (int i = 0; i < outout.length; i++) {
                System.out.println("outout " + i + " : " + outout[i]);
            }

            // Error term of every output unit, plus the summed squared error.
            double outErrSum = 0;
            for (int i = 0; i < OutTaskCount; ++i) {
                outDelta[i] = outout[i] * (1.0 - outout[i]) * (real[i] - outout[i]);
                System.out.println("real " + i + " : " + real[i]);
                System.out.println("outDelta " + i + " : " + outDelta[i]);
                outErrSum += Math.pow((real[i] - outout[i]), 2);
            }

            // Error term of every hidden unit; uses the output weights before they change.
            for (int i = 0; i < hideout.length; ++i) {
                double sum = 0;
                for (int j = 0; j < OutTaskCount; ++j) {
                    sum += outweight[j][i] * outDelta[j];
                }
                hideDelta[i] = hideout[i] * (1d - hideout[i]) * sum;
                System.out.println("hideDelta " + i + " : " + hideDelta[i]);
            }

            StringBuilder hideweightstr = new StringBuilder();
            StringBuilder outweightstr = new StringBuilder();

            // Adjust and encode the output-layer weights.
            for (int i = 0; i < OutTaskCount; ++i) {
                for (int j = 0; j < hideout.length; ++j) {
                    outweight[i][j] += rate * outDelta[i] * hideout[j];
                    outweightstr.append(outweight[i][j]);
                    outweightstr.append(",");
                }
                outweight[i][hideout.length] += rate * outDelta[i];
                outweightstr.append(outweight[i][hideout.length]);
                if (i != OutTaskCount - 1) {
                    outweightstr.append(":");
                }
            }

            // Adjust and encode the hidden-layer weights.
            for (int i = 0; i < hideout.length; ++i) {
                for (int j = 0; j < input.length; ++j) {
                    hideweight[i][j] += rate * hideDelta[i] * input[j];
                    hideweightstr.append(hideweight[i][j]);
                    hideweightstr.append(",");
                }
                hideweight[i][input.length] += rate * hideDelta[i];
                hideweightstr.append(hideweight[i][input.length]);
                if (i != hideout.length - 1) {
                    hideweightstr.append(":");
                }
            }

            // Encode the output activations; a builder avoids re-copying the string.
            StringBuilder outStr = new StringBuilder();
            for (int i = 0; i < outout.length; i++) {
                if (i > 0) {
                    outStr.append(",");
                }
                outStr.append(outout[i]);
            }

            String result = hideweightstr + "::" + outweightstr + "::" + outErrSum + "::"
                    + outStr;
            collector.emit(new Values(result, jobID));
            OutMap.remove(TNum);
            MyTuple.remove(TNum);
        }
        collector.ack(tuple);
    }

    @Override
    public void declareOutputFields(OutputFieldsDeclarer declarer) {
        declarer.declare(new Fields("result", "returnInfo"));
    }

    /** No-op: this bolt performs no work on cleanup. */
    @Override
    public void cleanup() {
    }

    /**
     * @return always {@code null}
     */
    @Override
    public Map<String, Object> getComponentConfiguration() {
        return null;
    }

    /** No-op implementation of the callback. */
    @Override
    public void finishedId(Object id) {
    }
}
把代码打包成bpdemo.jar上传storm server,运行命令 开始训练
|
|