1. Preparing the data
1.1 Pretraining data
So far the data comes in two parts. The first is the pretraining data, for which the corresponding images need to be downloaded. This set is a filtered version of CC-3M: roughly 595K images kept out of the original 3M. Each record has the following structure:
{
    "id": "GCC_train_002582585",
    "image": "GCC_train_002582585.jpg",
    "caption": "olive oil is a healthy ingredient used liberally .",
    "blip_caption": "some olive oil and salt in a glass bottle",
    "url": "https://i.pinimg.com/736x/de/13/3a/de133a155c777a9db265bb3e7888719d--colombo-olive-oils.jpg"
}
2. Downloading the LLaVA weights
The point here is that you need to reserve this much disk space: a 9B model takes 12.5 GB, two models get downloaded (12.5 + 12.5 = 25 GB), and merging them needs room for the result as well (25 + 12.5 = 37.5 GB). Now look at the code.
base_model is the LLaMA model and delta is the llava_llama model. Note that their vocabulary sizes differ: LLaVA's vocabulary is 4 entries larger than LLaMA's, because LLaVA adds extra special tokens to the tokenizer. The merge itself boils down to adding the delta's parameters onto the base model's, as in the sketch below.
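A minimal sketch of that merge, illustrative only: the real logic lives in LLaVA's apply_delta script, which loads the delta through the project's own model class rather than AutoModelForCausalLM. The paths here are placeholders, and the vocabulary-size mismatch is handled by adding the base weights only to the rows both models share:

import torch
from transformers import AutoModelForCausalLM

# Placeholder paths; the real merge is done by LLaVA's apply_delta script.
base = AutoModelForCausalLM.from_pretrained('path/to/llama', torch_dtype=torch.float16)
delta = AutoModelForCausalLM.from_pretrained('path/to/llava-delta', torch_dtype=torch.float16)

base_sd = base.state_dict()
for name, param in delta.state_dict().items():
    if name not in base_sd:
        continue  # delta-only tensors (e.g. mm_projector) are kept as-is
    if param.shape == base_sd[name].shape:
        param.data += base_sd[name]
    else:
        # vocab-size mismatch: only the first `rows` entries exist in the base
        rows = base_sd[name].shape[0]
        param.data[:rows] += base_sd[name]

delta.save_pretrained('path/to/merged-llava')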
The merge above produces the full model with the delta weights applied. From that model we now need to pull out just three groups of parameters on their own:
import argparse
import json
import os
from collections import defaultdict

import torch


def parse_args():
    parser = argparse.ArgumentParser(description='Extract MMProjector weights')
    parser.add_argument('--model_name_or_path', type=str, help='model folder')
    parser.add_argument('--output', type=str, help='output file')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    # The index file maps each parameter name to the shard that stores it, e.g.
    #   model.layers.21.mlp.up_proj.weight              -> pytorch_model-00001-of-00002.bin
    #   model.layers.21.post_attention_layernorm.weight -> pytorch_model-00001-of-00002.bin
    model_indices = json.load(open(os.path.join(args.model_name_or_path, 'pytorch_model.bin.index.json')))
    keys_to_match = ['mm_projector', 'embed_tokens', 'transformer.wte']
    ckpt_to_key = defaultdict(list)
    for k, v in model_indices['weight_map'].items():
        if any(key_match in k for key_match in keys_to_match):
            ckpt_to_key[v].append(k)
    # ckpt_to_key now groups the matched parameter names by shard:
    # defaultdict(<class 'list'>,
    #     {'pytorch_model-00001-of-00002.bin': ['model.embed_tokens.weight'],
    #      'pytorch_model-00002-of-00002.bin': ['model.mm_projector.bias', 'model.mm_projector.weight']})
    loaded_weights = {}
    for ckpt_name, weight_keys in ckpt_to_key.items():
        # Load each shard once and keep only the matched parameters,
        # e.g. pytorch_model-00001-of-00002.bin -> ['model.embed_tokens.weight'].
        ckpt = torch.load(os.path.join(args.model_name_or_path, ckpt_name), map_location='cpu')
        for k in weight_keys:
            loaded_weights[k] = ckpt[k]
    # Save just these parameters as the pretraining checkpoint,
    # e.g. checkpoints/llava-13b-pretrain.bin.
    torch.save(loaded_weights, args.output)
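To run it and sanity-check the output (paths are illustrative; the LLaVA repo ships an equivalent script under scripts/):

python extract_mm_projector.py \
    --model_name_or_path ./checkpoints/llava-13b \
    --output ./checkpoints/llava-13b-pretrain.bin

# Quick check that only the intended tensors were saved:
import torch
weights = torch.load('./checkpoints/llava-13b-pretrain.bin', map_location='cpu')
print(sorted(weights.keys()))
# expect: model.embed_tokens.weight, model.mm_projector.bias, model.mm_projector.weight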