Spark MLlib卷积神经网络(CNN)(附PPT下载)
在Github上看到了一个有意思的项目,在Spark MLlib上的卷积神经网络API,调用方法类似Keras。
下面是该项目的示例代码,几行简单的代码就实现了卷积神经网络(CNN):
import org.apache.spark.ml.dl._
import org.apache.spark.sql.SQLContext
val sqlContext = new SQLContext(sc)
val data = sqlContext.read.format("libsvm").load("path_to_dataset.txt")
val dataset = data.withColumnRenamed("label", "labels")
# Set up architecture for convolutional neural network
val model = new Sequential()
model.add(new Convolution2D(8, 1, 3, 3, 28, 28))
model.add(new Activation("relu"))
model.add(new Dropout(0.5))
model.add(new Dense(6272, 10))
model.add(new Activation("softmax"))
model.compile(loss="categorical_crossentropy",
optimizer=new Optimizer().adam(lr=.001),
metrics="Accuracy")
val trained = model.fit(dataset, num_iters=500)
项目地址:https://github.com/JeremyNixon/sparkdl
该项目在slideshare中有PPT介绍,PPT可加QQ群532836339下载。PPT截图和QQ群二维码如下: