
get_layer_and_variable_from tf.keras.Model

Date: 2022-09-19 11:03:05
Tags: layer, name, get, keras, dict, tf, model

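import tensorflow as tf


def get_layers_and_variables_from_model(model: tf.keras.Model, scope_name=None):
    """Return a flat dict of the model's layers and tf.Variable attributes.

    Keys have the form "<scope>/<name>"; nested models extend the scope.
    """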
    layer_dict = {}
    if scope_name is not None:
        base_name = scope_name
    else:
        base_name = model.name

    # get Layers
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):
            sub_model_layer_dict = get_layers_and_variables_from_model(
                layer, "{}/{}".format(base_name, layer.name)
            )
            layer_dict.update(sub_model_layer_dict)
        elif isinstance(layer, tf.keras.layers.Layer):
            layer_dict["{}/{}".format(base_name, layer.name)] = layer

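    # get Variables assigned directly as attributes of the model
    # (variables owned by the layers themselves are not listed here)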
    for attr_name, attr_value in model.__dict__.items():
        # NOTE: _train_counter, _test_counter, _predict_counter are
        # built-in variables of tf.keras.Model
        if attr_name not in [
            "_train_counter",
            "_test_counter",
            "_predict_counter",
        ] and isinstance(attr_value, tf.Variable):
            layer_dict["{}/{}".format(base_name, attr_value.name)] = attr_value
    return layer_dict

  

Recursion: the function calls itself because a keras.Model may contain another keras.Model among its layers.
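To see how the recursion builds the keys, here is a minimal framework-free sketch. `FakeLayer`, `FakeModel`, and `collect_layers` are hypothetical stand-ins invented for illustration; they are not part of TensorFlow and mimic only the `name` and `layers` attributes that the function above reads. It covers the layer half of the traversal only, not the variable collection.

```python
# Hypothetical stand-ins (not TensorFlow classes): they expose only the
# `name` and `layers` attributes that the recursion depends on.
class FakeLayer:
    def __init__(self, name):
        self.name = name


class FakeModel(FakeLayer):
    def __init__(self, name, layers):
        super().__init__(name)
        self.layers = layers


def collect_layers(model, scope_name=None):
    """Flatten nested models into {"scope/path/layer_name": layer}."""
    base_name = scope_name if scope_name is not None else model.name
    found = {}
    for layer in model.layers:
        key = "{}/{}".format(base_name, layer.name)
        if isinstance(layer, FakeModel):
            # Nested model: recurse with the extended scope prefix.
            # The nested model itself gets no entry, only its leaves do.
            found.update(collect_layers(layer, key))
        else:
            found[key] = layer
    return found


inner = FakeModel("inner", [FakeLayer("dense_1"), FakeLayer("dense_2")])
outer = FakeModel("outer", [FakeLayer("conv"), inner])
flat = collect_layers(outer)
print(sorted(flat))
# ['outer/conv', 'outer/inner/dense_1', 'outer/inner/dense_2']
```

As in the original function, a nested model contributes its leaf layers under a longer prefix but does not appear in the dict itself.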

Version: TensorFlow 2.5.1

From: https://www.cnblogs.com/deepllz/p/16601117.html
