首页 > 其他分享 >rasa train nlu详解:1.2-_train_graph()函数

rasa train nlu详解:1.2-_train_graph()函数

时间:2023-11-11 23:14:04浏览次数:64  
标签:training nlu 1.2 rasa graph train model

  本文使用《使用ResponseSelector实现校园招聘FAQ机器人》中的例子,主要详解介绍_train_graph()函数中变量的具体值。

一.rasa/model_training.py/_train_graph()函数
  _train_graph()函数实现,如下所示:

def _train_graph(
    file_importer: TrainingDataImporter,
    training_type: TrainingType,
    output_path: Text,
    fixed_model_name: Text,
    model_to_finetune: Optional[Union[Text, Path]] = None,
    force_full_training: bool = False,
    dry_run: bool = False,
    **kwargs: Any,
) -> TrainingResult:
    if model_to_finetune:  # 如果有模型微调
        model_to_finetune = rasa.model.get_model_for_finetuning(model_to_finetune)  # 获取模型微调
        if not model_to_finetune:  # 如果没有模型微调
            rasa.shared.utils.cli.print_error_and_exit(  # 打印错误并退出
                f"No model for finetuning found. Please make sure to either "   # 没有找到微调模型。请确保
                f"specify a path to a previous model or to have a finetunable " # 要么指定一个以前模型的路径,要么有一个可微调的
                f"model within the directory '{output_path}'."                  # 在目录'{output_path}'中的模型。
            )

        rasa.shared.utils.common.mark_as_experimental_feature(  # 标记为实验性功能
            "Incremental Training feature"  # 增量训练功能
        )

    is_finetuning = model_to_finetune is not None  # 如果有模型微调

    config = file_importer.get_config()  # 获取配置
    recipe = Recipe.recipe_for_name(config.get("recipe"))  # 获取配方
    config, _missing_keys, _configured_keys = recipe.auto_configure(  # 自动配置
        file_importer.get_config_file_for_auto_config(),  # 获取自动配置的配置文件
        config,  # 配置
        training_type,  # 训练类型
    )
    model_configuration = recipe.graph_config_for_recipe(  # 配方的graph配置
        config,  # 配置
        kwargs,  # 关键字参数
        training_type=training_type,  # 训练类型
        is_finetuning=is_finetuning,  # 是否微调
    )
    rasa.engine.validation.validate(model_configuration)  # 验证

    tempdir_name = rasa.utils.common.get_temp_dir_name()  # 获取临时目录名称

    # Use `TempDirectoryPath` instead of `tempfile.TemporaryDirectory` as this leads to errors on Windows when the context manager tries to delete an already deleted temporary directory (e.g. https://bugs.python.org/issue29982)
    # 翻译:使用TempDirectoryPath而不是tempfile.TemporaryDirectory,因为当上下文管理器尝试删除已删除的临时目录时,这会导致Windows上的错误(例如https://bugs.python.org/issue29982)
    with rasa.utils.common.TempDirectoryPath(tempdir_name) as temp_model_dir:  # 临时模型目录
        model_storage = _create_model_storage(  # 创建模型存储
            is_finetuning, model_to_finetune, Path(temp_model_dir)  # 是否微调,模型微调,临时模型目录
        )
        cache = LocalTrainingCache()  # 本地训练缓存
        trainer = GraphTrainer(model_storage, cache, DaskGraphRunner)  # Graph训练器

        if dry_run:  # dry运行
            fingerprint_status = trainer.fingerprint(                        # fingerprint状态
                model_configuration.train_schema, file_importer              # 模型配置的训练模式,文件导入器
            )
            return _dry_run_result(fingerprint_status, force_full_training)  # 返回dry运行结果

        model_name = _determine_model_name(fixed_model_name, training_type)  # 确定模型名称
        full_model_path = Path(output_path, model_name)                # 完整的模型路径

        with telemetry.track_model_training(                    # 跟踪模型训练
            file_importer, model_type=training_type.model_type  # 文件导入器,模型类型
        ):
            trainer.train(                               # 训练
                model_configuration,                     # 模型配置
                file_importer,                           # 文件导入器
                full_model_path,                         # 完整的模型路径
                force_retraining=force_full_training,    # 强制重新训练
                is_finetuning=is_finetuning,             # 是否微调
            )
            rasa.shared.utils.cli.print_success(         # 打印成功
                f"Your Rasa model is trained and saved at '{full_model_path}'."  # Rasa模型已经训练并保存在'{full_model_path}'。
            )

        return TrainingResult(str(full_model_path), 0)   # 训练结果

1.传递来的形参数据

2._train_graph()函数组成
  该函数主要由3个方法组成,如下所示:

  • model_configuration = recipe.graph_config_for_recipe(*)
  • trainer = GraphTrainer(model_storage, cache, DaskGraphRunner)
  • trainer.train(model_configuration, file_importer, full_model_path, force_retraining, is_finetuning)

二._train_graph()函数中的方法
1.file_importer.get_config()
  将config.yml文件转化为dict类型,如下所示:

2.Recipe.recipe_for_name(config.get("recipe"))

(1)ENTITY_EXTRACTOR = ComponentType.ENTITY_EXTRACTOR
实体抽取器。
(2)INTENT_CLASSIFIER = ComponentType.INTENT_CLASSIFIER
意图分类器。
(3)MESSAGE_FEATURIZER = ComponentType.MESSAGE_FEATURIZER
消息特征化。
(4)MESSAGE_TOKENIZER = ComponentType.MESSAGE_TOKENIZER
消息Tokenizer。
(5)MODEL_LOADER = ComponentType.MODEL_LOADER
模型加载器。
(6)POLICY_WITHOUT_END_TO_END_SUPPORT = ComponentType.POLICY_WITHOUT_END_TO_END_SUPPORT
非端到端策略支持。
(7)POLICY_WITH_END_TO_END_SUPPORT = ComponentType.POLICY_WITH_END_TO_END_SUPPORT
端到端策略支持。

3.model_configuration = recipe.graph_config_for_recipe(*)
  model_configuration.train_schema和model_configuration.predict_schema的数据类型都是GraphSchema类对象,分别表示在训练和预测时所需要的SchemaNode,以及SchemaNode在GraphSchema中的依赖关系。

(1)model_configuration.train_schema

  • schema_validator:rasa.graph_components.validators.default_recipe_validator.DefaultV1RecipeValidator类中的validate方法
  • finetuning_validator:rasa.graph_components.validators.finetuning_validator.FinetuningValidator类中的validate方法
  • nlu_training_data_provider:rasa.graph_components.providers.nlu_training_data_provider.NLUTrainingDataProvider类中的provide方法
  • train_JiebaTokenizer0:rasa.nlu.tokenizers.jieba_tokenizer.JiebaTokenizer类中的train方法
  • run_JiebaTokenizer0:rasa.nlu.tokenizers.jieba_tokenizer.JiebaTokenizer类中的process_training_data方法
  • run_LanguageModelFeaturizer1:rasa.nlu.featurizers.dense_featurizer.lm_featurizer.LanguageModelFeaturizer类中的process_training_data方法
  • train_DIETClassifier2:rasa.nlu.classifiers.diet_classifier.DIETClassifier类中的train方法
  • train_ResponseSelector3:rasa.nlu.selectors.response_selector.ResponseSelector类中的train方法

说明:ResponseSelector类继承自DIETClassifier类。

(2)model_configuration.predict_schema

  • nlu_message_converter:rasa.graph_components.converters.nlu_message_converter.NLUMessageConverter类中的convert_user_message方法
  • run_JiebaTokenizer0:rasa.nlu.tokenizers.jieba_tokenizer.JiebaTokenizer类中的process方法
  • run_LanguageModelFeaturizer1:rasa.nlu.featurizers.dense_featurizer.lm_featurizer.LanguageModelFeaturizer类中的process方法
  • run_DIETClassifier2:rasa.nlu.classifiers.diet_classifier.DIETClassifier类中的process方法
  • run_ResponseSelector3:rasa.nlu.selectors.response_selector.ResponseSelector类中的process方法
  • run_RegexMessageHandler:rasa.nlu.classifiers.regex_message_handler.RegexMessageHandler类中的process方法

4.tempdir_name
  'C:\Users\ADMINI~1\AppData\Local\Temp\tmpg0v179ea'

5.trainer = GraphTrainer(*)和trainer.train(*)
  这里执行的代码是rasa/engine/training/graph_trainer.py中GraphTrainer类的train()方法,实现功能为训练和打包模型并返回预测graph运行程序。

6.Rasa中GraphComponent的子类


参考文献:
[1]https://github.com/RasaHQ/rasa
[2]rasa 3.2.10 NLU模块的训练:https://zhuanlan.zhihu.com/p/574935615
[3]rasa.engine.graph:https://rasa.com/docs/rasa/next/reference/rasa/engine/graph/

标签:training,nlu,1.2,rasa,graph,train,model
From: https://www.cnblogs.com/shengshengwang/p/17826532.html

相关文章

  • rasa train nlu详解:1.1-train_nlu()函数
      本文使用《使用ResponseSelector实现校园招聘FAQ机器人》中的例子,主要详解介绍train_nlu()函数中变量的具体值。一.rasa/model_training.py/train_nlu()函数  train_nlu()函数实现,如下所示:deftrain_nlu(config:Text,nlu_data:Optional[Text],output:T......
  • train_logReg_param.o:train_logReg_param.cc:(.text+0x3407): more undefined refere
     001、make编译报错:train_logReg_param.o:train_logReg_param.cc:(.text+0x3407):moreundefinedreferencesto`std::__throw_out_of_range_fmt(charconst*,...)'follow 002、解决方法(可能是gcc版本的问题)a、gcc当前版本:(py38)[[email protected]]#gcc......
  • eTest_v1.1.23版本 支持拖拽功能
    引言随着互联网技术的快速发展,自动化测试已经成为确保软件质量的关键手段。eTest是一款强大的自动化测试工具,近期发布了eTest_v1.1.23版本,同时其搭配使用的Chrome插件也更新至Chrome插件v_1.3.5。这两个版本的更新为使用者带来了许多新的功能和优化,其中包括eTest的回放拖拽......
  • K8S基础:搭建K8S集群(v1.27.6)
    Kubernetes 是一个可移植、可扩展的开源平台,用于管理容器化的工作负载和服务,可促进声明式配置和自动化。Kubernetes拥有一个庞大且快速增长的生态,其服务、支持和工具的使用范围相当广泛。准备节点主机名IP系统&内核配置master01k8s0110.70.5.190Centos7.9,Kernel5.4.259-1.el7.el......
  • PreTrainedModel 中 from_pretrained 和 load_state_dict 的关联
    from_pretrainedAutoTokenizer.from_pretrained:从path/vocab.json中加载tokenizerAutoConfig.from_pretrained:从path/config.json中加载模型配置信息更新模型配置信息:model=Model(config)PreTrainedModel.from_pretrained:加载模型结构和模型参数load_checkpoint......
  • centos7.9安装kubernetes1.27.4版本
    ./etc/os-releasecolor(){RES_COL=60MOVE_TO_COL="echo-en\\033[${RES_COL}G"SETCOLOR_SUCCESS="echo-en\\033[1;32m"SETCOLOR_FAILURE="echo-en\\033[1;31m"SETCOLOR_WARNING="echo-en\\0......
  • 文心一言 VS 讯飞星火 VS chatgpt (131)-- 算法导论11.2 3题
    三、用go语言,Marley教授做了这样一个假设,即如果将链模式改动一下,使得每个链表都能保持已排好序的顺序,散列的性能就可以有较大的提高。Marley教授的改动对成功查找、不成功查找、插入和删除操作的运行时间有何影响?文心一言:Marley教授的假设是关于调整链表(可能是一个散列链表)以保......
  • Kubernetes(v1.21)简介
    Kubernetes(v1.21)简介第一章:背景1.1部署方式的演变传统部署时代:①在物理机服务器上运行应用程序。②无法为应用程序定义资源边界。③导致资源分配问题。如果在物理服务器上运行多个应用程序,则可能会出现一个应用程序占用大部分资源的情况下,会导致其它应用程序的......
  • 文心一言 VS 讯飞星火 VS chatgpt (131)-- 算法导论11.2 3题
    三、用go语言,Marley教授做了这样一个假设,即如果将链模式改动一下,使得每个链表都能保持已排好序的顺序,散列的性能就可以有较大的提高。Marley教授的改动对成功查找、不成功查找、插入和删除操作的运行时间有何影响?文心一言:Marley教授的假设是关于调整链表(可能是一个散列链表)......
  • openresty 1.21.4.3 发布
    实际上openresty比较早就发布openresty1.21.4.3版本了,只是周边的构建包,就在昨天才发布的,核心是解决了nginx关于http2处理的问题对于使用nginx以及openresty的同学比较推荐进行升级参考资料https://openresty.org/en/changelog-1021004.htmlhttps://github.com/openresty/......