首页 > 其他分享 >解析lightgbm的txt模型文件

解析lightgbm的txt模型文件

时间:2024-11-08 13:57:28  浏览次数:1  
标签:jpmml 解析 lightgbm log pmml org import evaluator txt

参考 GitHub 上的 jpmml-lightgbm 方案,实现加载 LightGBM 导出的 txt 格式模型文件:先将其转换为 PMML,再生成用于预测的 evaluator。

添加依赖

<!-- Maven dependencies for the LightgbmTxtInitializer example below; paste them inside the <dependencies> element of pom.xml. -->
<!-- NOTE(review): confirm these three versions are mutually compatible; pmml-lightgbm presumably brings its own pmml-model version transitively - verify with `mvn dependency:tree`. -->
<!-- pmml-lightgbm: provides org.jpmml.lightgbm (GBDT, LightGBMUtil, HasLightGBMOptions), used to load the LightGBM txt model and encode it as PMML. -->
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-lightgbm</artifactId>
    <version>1.5.4</version>
</dependency>
<!-- pmml-evaluator: provides org.jpmml.evaluator (ModelEvaluatorBuilder, ModelEvaluator, ProbabilityDistribution, ...), used to score the generated PMML. -->
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.6.6</version>
</dependency>
<!-- pmml-model: provides the org.dmg.pmml.PMML object model. -->
<!-- NOTE(review): the example also imports org.jpmml.model.metro.MetroJAXBUtil; presumably that class arrives transitively (pmml-model-metro) - verify. -->
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-model</artifactId>
    <version>1.6.6</version>
</dependency>

工具类

import lombok.extern.slf4j.Slf4j;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorBuilder;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.lightgbm.GBDT;
import org.jpmml.lightgbm.HasLightGBMOptions;
import org.jpmml.lightgbm.LightGBMUtil;
import org.jpmml.model.metro.MetroJAXBUtil;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;

import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Loads a LightGBM text model file, converts it to PMML and builds a PMML {@link ModelEvaluator} for it.
 * <p>
 * Depends on pmml-lightgbm 1.5.x (AGPL-3.0 License).
 * <p>
 * Converting the LightGBM model to PMML: @link https://github.com/jpmml/jpmml-lightgbm
 * <p>
 * Building the evaluator: @link https://github.com/jpmml/jpmml-evaluator
 */
@Slf4j
public class LightgbmTxtInitializer {

    /** Custom objective function; when null, the objective loaded from the model file is kept. */
    private static String objectiveFunction = null;
    /** Transform LightGBM-style trees to PMML-style trees. */
    private static boolean compact = true;
    /** Treat Not-a-Number (NaN) values as missing values. */
    private static boolean nanAsMissing = true;
    /** Limit the number of trees. Defaults (null) to all trees. */
    private static Integer numIteration = null;
    /** Target name. Defaults (null) to "_target". */
    private static String targetName = null;
    /** Target categories. Defaults (null) to the 0-based index [0, 1, .., num_class - 1]. */
    private static List<String> targetCategories = null;
    /**
     * Optional file path to dump the generated PMML document to, for debugging only.
     * When null (the default) nothing is written. The previous unconditional export went to the
     * hard-coded Windows path "E://t.pmml", which fails on machines without that drive.
     */
    private static String pmmlOutputPath = null;

    /**
     * Demo entry point: loads {@code lightgbm_model.txt} from the classpath, builds the evaluator
     * and scores one hand-written sample.
     */
    public static void main(String[] args) throws Exception {
        Resource resource = new ClassPathResource("lightgbm_model.txt");
        // try-with-resources: the stream is closed even if model loading or scoring fails.
        try (InputStream modelInputStream = resource.getInputStream()) {
            ModelEvaluator<?> evaluator = initEvaluator(modelInputStream);
            // Print the feature (input field) definitions expected by the model.
            List<InputField> inputFields = evaluator.getInputFields();
            log.info("ModelEvaluator featureNames:{}", inputFields);

            // Debug prediction. NOTE(review): the keys must match the model's feature names - confirm.
            Map<String, Number> waitPreSample = new HashMap<>(8);
            waitPreSample.put("0", 0.1);
            waitPreSample.put("1", 0.2);
            waitPreSample.put("2", 0.3);
            String predictedValue = getPredictedValue(waitPreSample, evaluator);
            log.info("Sample prediction:{}", predictedValue);
        }
    }

    /**
     * Loads a LightGBM text model, encodes it as PMML and builds a verified evaluator for it.
     * <p>
     * The return type is kept raw for backward compatibility with existing callers.
     *
     * @param pmmlFileInputStream stream over the LightGBM txt model. Despite the parameter name,
     *                            this is the LightGBM text model, not a PMML file. It is not closed here.
     * @return a verified evaluator for the converted model
     * @throws Exception if loading, conversion, optional export, or evaluator construction fails
     */
    @SuppressWarnings("rawtypes")
    public static ModelEvaluator initEvaluator(InputStream pmmlFileInputStream) throws Exception {
        long begin = System.nanoTime();
        GBDT gbdt = LightGBMUtil.loadGBDT(pmmlFileInputStream);
        log.info("Loaded GBDT in {} ms.", elapsedMillis(begin));

        if (objectiveFunction != null) {
            log.info("Setting custom objective function");
            gbdt.setObjectiveFunction(LightGBMUtil.parseObjectiveFunction(objectiveFunction));
        }

        Map<String, Object> options = new LinkedHashMap<>();
        options.put(HasLightGBMOptions.OPTION_COMPACT, compact);
        options.put(HasLightGBMOptions.OPTION_NAN_AS_MISSING, nanAsMissing);
        options.put(HasLightGBMOptions.OPTION_NUM_ITERATION, numIteration);

        // Encode the LightGBM model as a standard PMML document.
        begin = System.nanoTime();
        PMML pmml = gbdt.encodePMML(options, targetName, targetCategories);
        log.info("Converted GBDT to PMML in {} ms.", elapsedMillis(begin));

        // Exporting the PMML is not needed for evaluation; only do it when explicitly configured.
        if (pmmlOutputPath != null) {
            exportPmml(pmml, new File(pmmlOutputPath));
        }

        // Build and self-check the evaluator.
        begin = System.nanoTime();
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        ModelEvaluatorBuilder modelEvaluatorBuilder = new ModelEvaluatorBuilder(pmml);
        modelEvaluatorBuilder.setModelEvaluatorFactory(modelEvaluatorFactory);
        ModelEvaluator<?> evaluator = modelEvaluatorBuilder.build();
        evaluator.verify();
        log.info("Init evaluator in {} ms.", elapsedMillis(begin));
        return evaluator;
    }

    /**
     * Scores one sample and renders the target field's value as a string.
     *
     * @param argumentMap feature name to raw value
     * @param evaluator   evaluator built by {@link #initEvaluator(InputStream)}
     * @return the predicted value as a string, or null if the target value has an unsupported type
     */
    public static String getPredictedValue(Map<String, ?> argumentMap,
                                           ModelEvaluator<?> evaluator) {
        Map<String, ?> evaluateResult = evaluator.evaluate(argumentMap);
        log.info("evaluateResult:{}", evaluateResult);

        TargetField targetField = evaluator.getTargetField();
        Object targetFieldValue = evaluateResult.get(targetField.getFieldName());

        String predictedValue = null;
        if (targetFieldValue instanceof ProbabilityDistribution) {
            predictedValue = ((ProbabilityDistribution<?>) targetFieldValue).getPrediction().toString();
            log.info("Predicted value(ProbabilityDistribution) : {}", predictedValue);
        } else if (targetFieldValue instanceof FieldValue) {
            predictedValue = ((FieldValue) targetFieldValue).asString();
            log.info("Predicted value(FieldValue) : {}", predictedValue);
        } else if (targetFieldValue instanceof Number) {
            // Plain numeric results; these were previously reported as "unknown type".
            predictedValue = targetFieldValue.toString();
            log.info("Predicted value(Number) : {}", predictedValue);
        } else if (targetFieldValue instanceof List) {
            // Elements are not guaranteed to be FieldValue instances, so render each one defensively.
            predictedValue = ((List<?>) targetFieldValue)
                    .stream()
                    .map(LightgbmTxtInitializer::valueAsString)
                    .collect(Collectors.joining(","));
            log.info("Predicted value(List) : {}", predictedValue);
        } else {
            log.error("unknown type for targetFieldValue:{}", targetFieldValue);
        }
        return predictedValue;
    }

    /** Writes the PMML document to the given file, closing the stream even on failure. */
    private static void exportPmml(PMML pmml, File outputFile) throws Exception {
        long begin = System.nanoTime();
        try (OutputStream os = new FileOutputStream(outputFile)) {
            MetroJAXBUtil.marshalPMML(pmml, os);
        }
        log.info("Marshalled PMML to {} in {} ms.", outputFile, elapsedMillis(begin));
    }

    /** Renders one evaluation result element as a string, unwrapping FieldValue when present. */
    private static String valueAsString(Object value) {
        return (value instanceof FieldValue) ? ((FieldValue) value).asString() : String.valueOf(value);
    }

    /** Milliseconds elapsed since {@code startNanos}, measured with the monotonic clock. */
    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000L;
    }
}

标签:jpmml,解析,lightgbm,log,pmml,org,import,evaluator,txt
From: https://www.cnblogs.com/yu007/p/18534935

相关文章

  • 2个月搞定计算机二级C语言——真题(10)解析qg
    合集-3个月搞定计算机二级C语言(6)1.2个月搞定计算机二级C语言——真题(5)解析10-292.2个月搞定计算机二级C语言——真题(6)解析10-303.2个月搞定计算机二级C语言——真题(7)解析11-034.2个月搞定计算机二级C语言——真题(8)解析11-035.2个月搞定计算机二级C语言——真题(9)解析11-06:Flow......
  • 理解Web登录机制:会话管理与跟踪技术解析(四)-拦截器Interceptor、异常处理
    本文将详细探讨如何通过拦截器实现登录校验,并介绍如何通过异常处理来确保系统的鲁棒性。我们将通过具体的示例,深入分析如何在Spring框架中配置拦截器与异常处理,以便为开发者提供一套高效、安全的登录校验和异常管理方案。目录前言拦截器Interceptor快速入门Interceptor......
  • 2个月搞定计算机二级C语言——真题(10)解析
    1.前言本篇我们讲解2个月搞定计算机二级C语言——真题102.程序填空题2.1题目要求2.2提供的代码#include<stdio.h>#pragmawarning(disable:4996)doublefun(doublex[],intn){ inti,k=0; doubleavg=0.0,sum=0.0; for(i=0;i<n;i++) avg......
  • 【LGBM】LightGBM sklearn API超参数解释与使用方法(优化)
            接下来我们进一步解释LGBM的sklearnAPI中各评估器中的超参数及使用方法。  在LGBM的sklearnAPI中,总共包含四个模型类(也就是四个评估器),分别是lightgbm.LGBMModel、LGBMClassifier和LGBMRegressor以及LGBMRanker:LGBMModel  LGBMModel是LightGBM的......
  • 高强度低合金结构钢SA572Gr42、SA572Gr50、SA572Gr55、SA572Gr60、SA572Gr65解析与解
    1适用范围1.1本标准包括5个级别的高强度低合金结构钢SA572Gr42、SA572Gr50、SA572Gr55、SA572Gr60、SA572Gr65的型材、钢板、薄板桩和棒材。级别42(290),50(345)和55(380)拟用于铆接、螺栓连接的或焊接的结构用途。级别60(415),和65(450)拟用于桥梁上的铆接、或螺栓连接......
  • 并发编程(8)—— std::async、std::future 源码解析
    文章目录八、day81.std::async2.std::future2.1wait()2.2get()八、day8之前说过,std::async内部的处理逻辑和std::thread相似,而且std::async和std::future有密不可分的联系。今天,通过对std::async和std::future源码进行解析,了解二者的处理逻辑和关系。源码均基......
  • P10954 LCIS 题目解析
    P10954LCIS题目解析题目链接思路前置:弱化版没什么好说的,设\(f_{i,j}\)表示\(a\)的前\(i\)个并且结尾为\(b_j\)的最长上升公共子序列。定义\(a_0=b_0=-\infty.\)转移:\(a_i=b_j,f_{i,j}=\max_{k\in[0,j-1]\text{且}b_k<a_i}f_{i-1,k}.\)否则,\(f_{i,j}=f_{......
  • SP15637 GNYR04H - Mr Youngs Picture Permutations 解析
    SP15637GNYR04H-MrYoungsPicturePermutations解析题目链接分析题目性质大意就是给\(k\)排然后每个数列单调,每个横列单调,求满足这样排列的方案数。我们发现:与其为每个位置分配某个学生不如考虑将每个学生分给某个位置。思路根据以上,不妨设:\(f_{a_1,a_2,a_3,a_4,a_5}......
  • WPF 中 NavigationWindow 与 Page 的继承关系解析
    官网解析:NavigationWindow类   |    Page类publicclassBaseWindow:NavigationWindow{}publicpartialclassCountPage:Page{}都是创建的WPF界面有什么区别?在WPF(WindowsPresentationFoundation)开发中,我们经常需要设计具有多个页面的应用程序。在......
  • 实用GIS工具箱对比:GISBox等倾斜摄影切片软件的优缺点解析
    在地理信息系统(GIS)领域,强大的工具可以帮助用户更高效地进行数据处理、分析和可视化。本文介绍了五款实用的GIS工具箱——GISBox、QGIS、ArcGISOnline、GlobalMapper、MapTiler。它们各自在数据编辑、格式转换、地图发布和切片等方面具有独特的功能,能够满足从地理数据管理到复杂......