首页 > 编程问答 >是否可以限制 scikit learn 模型仅预测某些标签?

是否可以限制 scikit learn 模型仅预测某些标签?

时间:2024-07-24 12:29:55浏览次数:12  
标签:python scikit-learn

我有两个模型在多个标签上进行了训练,并用它来预测游戏的类型。我注意到,由于模型经过训练,有时相同的输入数据可能会让两个模型输出截然不同的流派。

我想将预测限制为另一个模型建议的内容,但不知道该怎么做。下面的示例

Model1_labels = ["JRPG", "Horror", "FPS", "Platformer"]
Model2_Labels = ["Mario", "War_shooter", "fantasy_rpg"]

training_data = Label1      Label2            Tags
               JRPG        fantasy_rpg       open_world, action, level-up, fantasy
               JRPG        fantasy_rpg       level-up, turn-based, fantasy
               FPS         War_shooter       open-world, 1st person, tanks, planes, shooter
               FPS         War_shooter       1st person, war, shooter, level-up
               JRPG        Mario             level-up, turn-based, shooter
               ...

从示例中,War_shooter 只能是 FPS,因为战争射击游戏的描述是在战争期间设置的 FPS 游戏。

但如何限制?

关于我如何训练和预测的代码以下: SDG_PARAMS_DICT:最终[Dict[str,Any]] = dict(alpha = 1e-5,penalty =“l2”,max_iter = 1000,loss =“log_loss”) VECTORIZER_PARAMS_DICT: Final[Dict[str, Any]] = dict(ngram_range=(1, 4), min_df=5, max_df=0.8)

def build_model(x_data, y_data) -> Pipeline:
    game_predict_pipeline = Pipeline(
        [
            ("vect", CountVectorizer(VECTORIZER_PARAMS_DICT)),
            ("tfidf", TfidfTransformer()),
            ("clf", SelfTrainingClassifier(SGDClassifier(**SDG_PARAMS_DICT), verbose=True)),
        ]
    )
    X_train, X_test, y_train, y_test = train_test_split(x_data,
                                                        y_data,
                                                        train_size=0.3)
     game_predict_pipeline.fit(X_train, y_train)

     return game_predict_pipeline 

 game_data = pd.read_excel("c:/my_game_data.xlsx", keep_default_na=False)

 model1 = build_model(game_data["Tag"], game_data["Label1"])
 model2 = build_model(game_data["Tag"], game_data["Label2"])

 test_tags = "level-up, open-world, shooter"

 model1.predict(test_tag)
 model2.predict(test_tag)

结果

model1 - correct
   FPS
model2 - incorrect
   Mario

我考虑过执行 Predict_proba 并从列表中删除标签,但他的不会改变预测的概率,导致许多分数未达到理论截止值

comparison_dict: Dict = {"FPS":["War_shooter"]}


prediction2: np.ndarray = model2.predict_proba(test_tag)
classes: np.ndarray = model2.classes_
prediction_dict: Dict = {}
for idx, model_cls in enumerate(classes):
    if model_cls in comparison_dict.get(model1_precition):
        if prediction2[0][idx] < 0.6: # cutoff
              prediction_dict[model_cls] = prediction2[0][idx]

结果

output predict_proba
    Mario      : 0.43
    War_shooter: 0.42
    fantasy_rpg: 0.15

output
    None

output without cutoff
    "War_shooter": 0.42

遇到的问题是的两个模型独立运作,无法利用彼此的预测来完善结果。可以通过几种方法来限制模型的预测并合并它们的知识:

1. 基于规则的后期处理:

  • 这是提出的解决方案的一种变体。在的后处理步骤中,不是使用任意的截止值,而是实现明确的规则,基于一个模型的预测来过滤另一个模型的预测。
  • 例如,可以创建一个规则字典,例如:

    python restriction_rules = { "FPS": ["War_shooter"], "fantasy_rpg": ["JRPG"], }

  • 然后,可以修改预测逻辑如下:

    ```python model1_prediction = model1.predict(test_tag)[0] model2_prediction = model2.predict(test_tag)[0]

    allowed_predictions = restriction_rules.get(model1_prediction, [])

    if model2_prediction in allowed_predictions: final_prediction = model2_prediction else: # 使用默认值、请求用户输入或跳过预测 final_prediction = "Unknown" print(final_prediction) ```

2. 层叠分类器:

  • 在此方法中,可以训练第三个模型(“元分类器”),该模型将前两个模型的预测作为输入。
  • 训练此元分类器以根据初始预测输出最终预测。
  • 这种方法允许学习更复杂的规则,而不仅仅是简单的过滤。

    ```python from sklearn.ensemble import RandomForestClassifier

    # Get predictions from the first two models model1_predictions = model1.predict(game_data["Tag"]) model2_predictions = model2.predict(game_data["Tag"])

    # Combine predictions into a new feature matrix combined_predictions = np.column_stack((model1_predictions, model2_predictions))

    # Train the meta-classifier meta_classifier = RandomForestClassifier() meta_classifier.fit(combined_predictions, game_data["Label2"])

    # For prediction: test_prediction1 = model1.predict(test_tag) test_prediction2 = model2.predict(test_tag) combined_test_prediction = np.array([test_prediction1, test_prediction2]).reshape(1, -1) final_prediction = meta_classifier.predict(combined_test_prediction) ```

3. 结构化预测:

  • 如果的标签之间存在固有的依赖关系(如的示例中所示),则可以探索结构化分类方法,例如条件随机场 (CRF)。
  • CRF 可以将不同标签之间的依赖关系考虑在内,并产生更一致的预测。
  • 但是,与简单的分类器相比,CRF 的设置和调整更复杂。

4. 调整预测概率阈值:

  • 可以调整模型 2 的 predict_proba 方法中的概率阈值,而不是完全删除标签。
  • 这将允许控制模型在分配标签之前的置信度级别。

    python threshold = 0.5 probabilities = model2.predict_proba(test_tag)[0] filtered_predictions = [model2.classes_[i] for i, prob in enumerate(probabilities) if prob >= threshold]

选择哪种方法取决于数据的特定特征以及需要达到的复杂程度。基于规则的后处理是最容易实现的,而层叠分类器和结构化预测提供了更高的灵活性,但也需要更多数据和调整。

标签:python,scikit-learn
From: 78786374

相关文章

  • 【Dison夏令营 Day 28】用 Python 创建恐龙游戏
    谁没有玩过谷歌著名的“恐龙游戏”?也许每个人都玩过这个游戏。今天,在这篇文章中,我们将帮助你用Python开发一个恐龙游戏。本教程将深入讲解每一行代码,并提供参考资料。我们将尽力让读者详细、透彻地理解这个项目。Python版恐龙游戏的任务记录包括图片文档和Python资料......
  • Python 无法 pickle 自定义类型
    我正在尝试在ProcessPool中运行一个函数,该函数将通过读取python文件并运行生成的类中的方法来加载一些自定义类。我遇到的错误是TypeError:cannotpickle'generator'object该方法需要返回一个生成器。我该如何解决这个问题,谢谢。我用谷歌搜索但没有运气。......
  • python 语法无效?
    我试图编写一些Python代码,但由于某些奇怪的原因,它重复了无效的语法,我不知道最大的问题是什么。这些行是文件中唯一的代码行。Age=int(input("Howoldareyou?:"))ifAge>=18:print("YouareaAdult!")我尝试更改行,因为这似乎是我的生气,但它没有做任何帮......
  • 如何在Python的matplotlib中将条形标签绘制到右侧并为条形标签添加标题?
    我已经在python中的matplotlib中创建了一个图表,但是以下代码中的最后一行不允许在图表之外对齐条形标签。importmatplotlib.pyplotaspltg=df.plot.barh(x=name,y=days)g.set_title("Dayspeopleshowedup")g.bar_label(g.containers[0],label_type='edge')我得......
  • 19、Python之容器:快来数一数,24678?Counter能数得更好
    引言关于数据的分组计数,前面的文章中已经涉及了很多次。眼下要进行分组计数,我们可用的方法有:1、直接使用dict进行计数,需要对首次出现的键进行判断初始化的操作;2、使用dict的setdefault()方法进行计数,代码可以简化一些,虽然方法名有点怪;3、defaultdict进行计数,可以设置自动......
  • 如何使用 C# 检查用户是否安装了最低 Python 版本并且可以访问我的代码?
    我正在开发一个C#程序,该程序必须为一项特定任务运行一些Python代码。(Python代码很复杂,是由另一个团队开发的。无法在C#中重现其功能。)我正在尝试更新我的程序的安装程序文件以解决此问题:我希望它检查用户是否(谁正在安装我的程序)已安装Python并且它满足我的最低版......
  • 如何优雅地将复杂的Python对象和SQLAlchemy对象模型类结合起来?
    我有一个相当复杂的类,具有从提供的df到init计算的复杂属性,这些属性可能是最终可以序列化为字符串的其他类类型。在Python中,我想处理对象而不是原始类型,但也想使用SQLAlchemy与数据库交互。表中的列与许多类属性相同,如何优雅地组合这两个类?我可以使用组合并将数据......
  • Python Match Case:检查未知长度的可迭代内部的类型
    我想使用匹配大小写检查一个未知长度的迭代(假设为list)仅包含给定类型(假设为float)(还有其他情况,只有这个给我带来了问题)。case[*elems]ifall([isinstance(elem,float)foreleminelems]):returnnum这个似乎可行,但确实很不Pythony。看来应该有更简单的方法。......
  • Python实现excel数据的读取和写入
    1.安装说到前面的话,实现excel文件数据的读取和写入,在python中还有其它方法,比如说pandas。鉴于最近粉丝朋友问到上面的“xlrd”和“xlwt”,那么笔者下面将通过这两个方法,来实现excel文件数据的读取和写入。首先,我们先需要提前安装好对应的库。需要注意的是,xlrd从2.0版本开始,只......
  • python_进程与线程_多线程
    一、程序与进程的概念1、进程:指启动后的程序,系统会为进程分配内存空间二、创建进程的方式1、第一种创建进程的方式process(group=None,target,name,args,kwargs)group:表示分组,实际上不使用,默认为None即可target:表示子进程要执行的任务,支持函数名name:表示子进程的......