获得tensorflow模型的参数量

   tensorflow模型训练好之后,通常会保存为.ckpt文件,有时我们想了解一下模型保存了多少参数,这固然可以手动计算,但是速度太慢,这里我写了一个程序可以直接获得模型的参数量,下面是源代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from tensorflow.python import pywrap_tensorflow
import os
import numpy as np
model_dir = "models_pretrained/"
checkpoint_path = os.path.join(model_dir, "model.ckpt-82798")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
total_parameters = 0
for key in var_to_shape_map:#list the keys of the model
# print(key)
# print(reader.get_tensor(key))
shape = np.shape(reader.get_tensor(key)) #get the shape of the tensor in the model
shape = list(shape)
# print(shape)
# print(len(shape))
variable_parameters = 1
for dim in shape:
# print(dim)
variable_parameters *= dim
# print(variable_parameters)
total_parameters += variable_parameters

print(total_parameters)

   这段代码很好理解,首先读取模型(reader= pywrap_tensorflow.NewCheckpointReader(checkpoint_path)),然后对每个key进行遍历,将每个key的参数量统计出来(具体通过得到其shape,将各维度进行相乘相加),然后累加这些key的参数量即可。

0%