首页 > 编程语言 >囚徒5.3_SSA_BP算法

囚徒5.3_SSA_BP算法

时间:2024-06-02 15:43:55浏览次数:26  
标签:5.3 fit self BP np weights model SSA size

麻雀算法加上bp网络


import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from sklearn.metrics import accuracy_score

# 加载MNIST数据集
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
X_test = X_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# 定义CNN模型
def create_model(weights=None):
    inputs = Input(shape=(28, 28, 1))
    x = Conv2D(32, kernel_size=(3, 3), activation='relu')(inputs)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    outputs = Dense(10, activation='softmax')(x)
    model = Model(inputs, outputs)
    if weights is not None:
        model.set_weights(weights)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# 定义麻雀算法参数
class SSA:
    def __init__(self, pop_size, dim, max_iter, lb, ub):
        self.pop_size = pop_size
        self.dim = dim
        self.max_iter = max_iter
        self.lb = lb
        self.ub = ub
        self.X = np.random.uniform(low=lb, high=ub, size=(pop_size, dim))
        self.p_fit = np.full(pop_size, float('inf'))
        self.g_best = None
        self.g_best_fit = float('inf')

    def fitness(self, x):
        # 将个体解转换为CNN权重
        weights = self.decode_weights(x)
        model = create_model(weights)
        model.fit(X_train, y_train, epochs=1, batch_size=32, verbose=0)
        y_pred = model.predict(X_test)
        return -accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1))

    def decode_weights(self, x):
        # 解码权重,这里假设dim与CNN权重的总数匹配
        model = create_model()
        shapes = [w.shape for w in model.get_weights()]
        weights = []
        idx = 0
        for shape in shapes:
            size = np.prod(shape)
            weights.append(x[idx:idx + size].reshape(shape))
            idx += size
        return weights

    def update(self):
        for t in range(self.max_iter):
            for i in range(self.pop_size):
                fit = self.fitness(self.X[i])
                if fit < self.p_fit[i]:
                    self.p_fit[i] = fit
                    if fit < self.g_best_fit:
                        self.g_best = self.X[i].copy()
                        self.g_best_fit = fit
            for i in range(self.pop_size):
                r1 = np.random.rand()
                r2 = np.random.rand()
                if r2 < 0.8:
                    self.X[i] = self.X[i] + r1 * (self.g_best - self.X[i])
                else:
                    self.X[i] = self.X[i] + r1 * (self.X[i] - self.g_best)
                self.X[i] = np.clip(self.X[i], self.lb, self.ub)

# 初始化麻雀算法
num_weights = sum([np.prod(w.shape) for w in create_model().get_weights()])
ssa = SSA(pop_size=10, dim=num_weights, max_iter=10, lb=-1, ub=1)
ssa.update()

# 输出最优解并评估
optimal_weights = ssa.decode_weights(ssa.g_best)
model = create_model(optimal_weights)
model.fit(X_train, y_train, epochs=5, batch_size=32, verbose=1)
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print("Test accuracy:", accuracy)

标签:5.3,fit,self,BP,np,weights,model,SSA,size
From: https://www.cnblogs.com/qt-pyq/p/18227199

相关文章

  • FFmpeg开发笔记(二十五)Linux环境给FFmpeg集成libwebp
    ​《FFmpeg开发实战:从零基础到短视频上线》一书介绍了JPEG、PNG、GIF等图片格式,以及如何通过FFmpeg把视频画面转存为这些格式。除了上述这些常见的图片格式,还有较新的WebP格式,它由VP8视频标准派生而来,VP8演进的视频格式叫做WebM,图片格式则叫WebP。若想让FFmpeg支持WebP图片的编......
  • 1day公开用友PLM-MessageService信息泄漏漏洞
       0x01阅读须知        技术文章仅供参考,此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等(包括但不限于)进行检测或维护参考,未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失,均由......
  • [2024.5.31晚~2024.6.1早鲜花] 余生的第一天
    [2024.5.31晚~2024.6.1早鲜花]余生的第一天来\(GF\)集训一两周了,宿舍居然有电梯,而且学生居然可以乘坐,\(GF\)的饭也十分好吃,比\(XF\)的好吃一万倍,听\(yzj\)说清华附的比\(GF\)好吃一万倍,难以想象了认识了好多别的学校的女生!大家都好可爱(●'◡'●),传奇的原神传教大师\(cyl\)有......
  • 5.31 CF R 949 (Div.2)
    5.31CFR949(Div.2)Solve:A~D(4/6)Rank:99Rating:\(1939+131=2070\)(\(1989+81=2070\))发挥评价:Normal失误:小失误是做2B时候没有注意,第一次错了之后就急了,接连交了\(4\)发罚时。注意如果交上去WA了,想清楚、找清楚问题再交。CF1981E(me*2200)给定\(n\)......
  • 2024.5.31 做题记录
    1.外星千足虫公元\(2333\)年\(2\)月\(3\)日,在经历了\(17\)年零\(3\)个月的漫长旅行后,“格纳格鲁一号”载人火箭返回舱终于安全着陆。此枚火箭由美国国家航空航天局(NASA)研制发射,行经火星、金星、土卫六、木卫二、谷神星、“张衡星”等\(23\)颗太阳系星球,并最终在小行......
  • python 通过 subprocess 运行的代码 exit(1) 不能使得pipeline fail
    在使用Python的subprocess模块运行外部命令时,如果你希望子进程的退出状态码能够影响Python脚本的执行结果,尤其是在使用管道(pipeline)时,你需要手动检查子进程的返回码并采取相应的措施。简单地使用subprocess.run或subprocess.call运行子进程并不会自动使Python脚......
  • 《旋转的快速傅里叶变换》——2024.5.31
    $$\aleph$$——发疯记录(无题,不知道起什么好,用前几天看书看到的符号阿列夫表示了)我很久没发过阶段性总结类的博文了,对比去年来是少之又少。一是因为我觉得现在的日子比去年枯燥多了;二是其实我平时会写记录,但没有总的;三是上了高中以后几次语文考试我的作文成绩都很差,老师说我写的......
  • SwiftUI中SafeArea的管理与使用(ignoresSafeArea, safeAreaPadding, safeAreaInset)
    SafeArea是指不与视图控制器提供的导航栏、选项卡栏、工具栏或其他视图重叠的内容空间。在UIKit中,开发人员需要使用safeAreaInsets或safeAreaLayoutGuide来确保视图被放置在界面的可见部分。SwiftUI彻底简化了上述过程,除非开发者明确要求视图突破安全区域的限制,否则SwiftU......
  • 制造行业对BPMN技术应用
    1. 制造行业关键特点        制造行业是一个重要的经济领域,具有以下关键特点:生产规模大:制造行业通常涉及大量的生产设备和工人,生产规模较大,需要高效的生产管理系统来确保生产顺利进行。​生产过程复杂:制造行业的生产过程通常涉及多个环节,包括原材料采购、生产加工......
  • Failed to resolve org.junit.jupiter:junit-jupiter-engine:5.3.1
    跟着尚硅谷学SSM测试这块老不对就没按视频里的用的junit4平替了今天的@BeforeEach平替不了搜了半天找不到解决方法就把报错的都整到依赖里了然后就好了也不知道具体咋回事<dependency><groupId>org.junit.jupiter</groupId><artifactId>junit......