weimin 发表于 2025-2-7 00:07:19

deeplearning4j~实现简单模型训练和测试

DeepLearning4j (DL4J) 是一个开源的深度学习库,专为 Java 和 Scala 设计。它可以用于构建、训练和部署深度学习模型。以下是关于如何使用 DL4J 的基本指南以及一个简单的模型训练示例。
本例中使用了MNIST数据集,MNIST(modified national institute of standard and technology)数据集是由Yann LeCun及其同事于1994年创建一个大型手写数字数据库(包含0~9十个数字)。MNIST数据集的原始数据来源于美国国家标准和技术研究院(national institute of standard and technology)的两个数据集:special database 1和special database 3。它们分别由NIST的员工和美国高中生手写的0-9的数字组成。原始的这两个数据集由128×128像素的黑白图像组成。LeCun等人将其进行归一化和尺寸调整后得到的是28×28的灰度图像。
DeepLearning4j 使用指南

安装与配置


[*]环境要求

[*]Java Development Kit (JDK) 8 或以上版本
[*]Maven(推荐)或 Gradle 用于项目管理

[*]创建 Maven 项目
在你的 IDE 中创建一个新的 Maven 项目,并在 pom.xml 文件中添加以下依赖:
<dependencies>    <!-- DL4J Core -->    <dependency>      <groupId>org.deeplearning4j</groupId>      <artifactId>deeplearning4j-core</artifactId>      <version>1.0.0-M1.1</version>    </dependency>    <!-- ND4J (Numpy for Java) -->    <dependency>      <groupId>org.nd4j</groupId>      <artifactId>nd4j-native-platform</artifactId>      <version>1.0.0-M1.1</version>    </dependency>    <!-- DataVec for data preprocessing -->    <dependency>      <groupId>org.datavec</groupId>      <artifactId>datavec-api</artifactId>      <version>1.0.0-M1.1</version>    </dependency></dependencies>
[*]更新 Maven 依赖
确保你的 IDE 更新了 Maven 依赖,下载所需的库。
简单的模型训练

下面是一个使用 DL4J 训练简单神经网络的示例,目标是对手写数字进行分类(MNIST 数据集)。
代码示例

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;import org.deeplearning4j.nn.conf.MultiLayerConfiguration;import org.deeplearning4j.nn.conf.NeuralNetConfiguration;import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;import org.deeplearning4j.optimize.listeners.ScoreIterationListener;import org.nd4j.linalg.activations.Activation;import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;import org.nd4j.linalg.learning.config.Adam;import org.nd4j.linalg.lossfunctions.LossFunctions;public class MnistExample {    public static void main(String[] args) throws Exception {      // 加载 MNIST 数据集      DataSetIterator mnistTrain = new MnistDataSetIterator(128, true, 12345);                // 配置神经网络      MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()                .seed(123)                .updater(new Adam(0.001))                .list()                .layer(0, new DenseLayer.Builder().nIn(784).nOut(256)                        .activation(Activation.RELU)                        .build())                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)                        .activation(Activation.SOFTMAX)                        .nIn(256).nOut(10).build())                .build();      // 创建并初始化网络      MultiLayerNetwork model = new MultiLayerNetwork(conf);      model.init();      model.setListeners(new ScoreIterationListener(100));// 每100次迭代输出一次分数      // 训练模型      for (int i = 0; i < 10; i++) { // 训练10个epoch            model.fit(mnistTrain);      }      System.out.println("训练完成!");      // 加载 MNIST 测试数据集      DataSetIterator mnistTest = new MnistDataSetIterator(128, false, 12345);      // 评估模型      double accuracy = model.evaluate(mnistTest).accuracy();      System.out.println("模型准确率: " + accuracy);      // 保存模型到文件      File modelFile = new File("mnist_model.zip");      ModelSerializer.writeModel(model, modelFile, true);    }}代码说明


[*]加载数据集:使用 MnistDataSetIterator 加载 MNIST 数据集。
[*]配置神经网络:

[*]使用 NeuralNetConfiguration.Builder 构建神经网络配置。
[*]添加输入层(DenseLayer)和输出层(OutputLayer)。

[*]创建和初始化模型:使用 MultiLayerNetwork 创建模型并初始化。
[*]训练模型:通过循环调用 fit() 方法训练模型。
运行示例

确保你的环境已正确设置,然后运行上述代码。模型将在 MNIST 数据集上进行训练,训练完成后会输出“训练完成!”的信息。
模型评估

在训练完模型后,通常需要对其进行评估,以了解模型在未见数据上的表现。你可以使用测试集来评估模型的准确性和其他性能指标。
保存与加载模型

训练完成后,你可能希望保存模型以便以后使用。DL4J 提供了简单的方法来保存和加载模型。
调整与优化模型

根据评估结果,你可能需要调整模型的超参数或架构。可以尝试以下方法:

[*]增加层数或节点数:增加模型的复杂性。
[*]改变学习率:试验不同的学习率以找到最佳值。
[*]使用不同的激活函数:例如,尝试 LeakyReLU 或 ELU。
[*]正则化:添加 Dropout 层或 L2 正则化以防止过拟合。
部署模型

如果你打算将模型应用于生产环境,可以考虑将其部署为服务。可以使用以下方式之一:

[*]REST API:将模型包装为 RESTful 服务,方便客户端调用。
[*]嵌入式应用:将模型嵌入到 Java 应用程序中,直接进行预测。
模型的测试

使用 Java 和 DeepLearning4j 来训练自己的手写数字图像(例如 0 到 9 的标准图像)是一个很好的项目。下面是一个简单的步骤指南,帮助你实现这个目标。
步骤概述


[*]准备数据:将你的数字图像准备为合适的格式。
[*]创建和配置模型:使用 DeepLearning4j 创建神经网络模型。
[*]训练模型:使用你的图像数据训练模型。
[*]评估和测试模型:验证模型的性能。
1. 准备数据

首先,你需要将你的 0-9 数字图像准备好。假设你有 10 张图像,每张图像都是 28x28 像素的灰度图像,并且它们存储在本地文件系统中。
模型测试的步骤

步骤 1: 使用 MNIST 数据集训练模型


[*]加载数据集:使用 MnistDataSetIterator 加载 MNIST 数据集。
[*]构建模型:根据你的需求,构建一个适合的神经网络模型。
[*]训练模型:使用 MNIST 数据集对模型进行训练。
[*]保存模型:将训练好的模型保存到文件中(例如,保存为 .zip 文件)。
步骤 2: 准备手写数字图片


[*]手写数字:自己手写一个数字 1,并拍照或扫描成图片。
[*]预处理图片:

[*]将图片转换为灰度图像。
[*]调整图片大小为 28x28 像素(MNIST 数据集中的标准尺寸)。
[*]对图像进行归一化处理(通常将像素值缩放到 范围内)。

步骤 3: 比较手写数字与 MNIST 数据集


[*]加载保存的模型:从 zip 文件中加载之前训练好的模型。
[*]预测手写数字:将预处理后的手写数字图片输入到模型中进行预测。
[*]输出结果:模型将输出手写数字的预测结果。你可以将这个结果与 MNIST 数据集中相应的标签进行比较。
注意事项


[*]数据预处理:确保手写数字的预处理方式与训练时一致,包括图像大小、颜色通道和归一化。
[*]模型评估:在比较之前,可以先在测试集上评估模型的性能,以确保其准确性。
[*]可视化结果:可以通过可视化工具(如 matplotlib)展示手写数字及其预测结果,以便更好地理解模型的表现。
示例代码

以下是一个简单的示例代码框架,展示了如何实现这些步骤

/** * @author lind * @date 2025/1/7 14:27 * @since 1.0.0 */public class MnistUtils {    /**   * 将图像转换为灰度图像   *   * @param original   * @return   */    private static BufferedImage convertToGrayscale(BufferedImage original) {      BufferedImage grayImage = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_BYTE_GRAY);      Graphics g = grayImage.getGraphics();      g.drawImage(original, 0, 0, null);      g.dispose();      return grayImage;    }    /**   * 调整图像大小   *   * @param original   * @param width   * @param height   * @return   */    private static BufferedImage resizeImage(BufferedImage original, int width, int height) {      Image scaledImage = original.getScaledInstance(width, height, Image.SCALE_SMOOTH);      BufferedImage resizedImage = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);      Graphics2D g2d = resizedImage.createGraphics();      g2d.drawImage(scaledImage, 0, 0, null);      g2d.dispose();      return resizedImage;    }    /**   * 加载图像   *   * @param fileName   * @return   */    public static INDArray loadGrayImg(String fileName) {      try {            // 1. 加载图片            BufferedImage originalImage = ImageIO.read(new File(fileName));            // 2. 转换为灰度图像            BufferedImage grayImage = convertToGrayscale(originalImage);            // 3. 调整大小为 28x28 像素            BufferedImage resizedImage = resizeImage(grayImage, 28, 28);            // 4. 进行归一化处理            return normalizeImage(resizedImage);      } catch (IOException e) {            e.printStackTrace();      }      return null;    }    /**   * 对图像进行归一化处理并生成 INDArray   *   * @param image   * @return   */    private static INDArray normalizeImage(BufferedImage image) {      int width = image.getWidth();      int height = image.getHeight();      double[] normalizedData = new double; // 创建一维数组      for (int y = 0; y < height; y++) {            for (int x = 0; x < width; x++) {                // 获取灰度值(0-255)                int grayValue = image.getRGB(x, y) & 0xFF; // 只取灰度部分                // 归一化到 范围                normalizedData = grayValue / 255.0; // 填充一维数组            }      }      // 将一维数组转换为 INDArray,并添加批次维度      INDArray indArray = Nd4j.create(normalizedData).reshape(1, 784); // reshape to       return indArray;    }}
public static void main(String[] args) throws IOException {      // 加载已训练的模型      MultiLayerNetwork model = MultiLayerNetwork.load(new File("E:\\github\\lind-deeplearning4j\\mnist_model.zip"), true);      // 测试图像路径      String testImagePath = "d:\\dlj4\\img\\";      // 假设你有10个测试图像,命名为 0.png 到 9.png,当我从MNIST数据集网站下载9张图片后,这个大模型确实可以给我识别出来      for (int i = 0; i <= 3; i++) {            String fileName = testImagePath + i + ".png";            System.out.println("fileName=" + fileName);            INDArray testImage = loadGrayImg(fileName);            INDArray output = model.output(testImage); // 进行预测            // 获取预测结果            int predictedClass = Nd4j.argMax(output, 1).getInt(0);            System.out.println("测试图像 " + i + " 的预测结果: " + predictedClass);      }    }模型测试结果,它会根据0-3的图片,将图片上面的数字分析出来,这个事实上是根据我们训练的MINIST数据集得到的结果
页: [1]
查看完整版本: deeplearning4j~实现简单模型训练和测试