__init__.py
Notes I took while learning how __init__.py works. The file consists entirely of imports; the line from .dataset import * is a wildcard import, and what it actually pulls in is controlled by the __all__ list defined in dataset.py.
from .modules import BayerDemosaick
from .modules import XTransDemosaick
from .mosaic import xtrans
from .mosaic import bayer
from .mosaic import xtrans_cell
from .dataset import *

dataset.py
A lot of the code in this file is identical to download_dataset.py [see my reading notes on download_dataset.py].
"""Dataset loader for demosaicnet."""
import os
import platform
import subprocess
import shutil  # module for copying and deleting files
import hashlib  # for computing hash (MD5) values

import numpy as np
# image read/write module
from imageio import imread
from torch.utils.data import Dataset as TorchDataset
import wget  # at first I thought this was a module the author wrote, but it is a third-party package installed alongside torch
[ttools documentation] [my reading notes on ttools]
import ttools  # a module from the installed ttools package
from .mosaic import bayer, xtrans
Side note: in Python, putting a trailing comma after a list (or tensor) expression turns the whole thing into a tuple containing that object. Why does Python allow a trailing comma at the end of lists and tuples? See the sketch below.
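A quick illustration of both points (throwaway code, not from the repo):

# A trailing comma after a bare expression creates a one-element tuple:
coords = [1, 2, 3],       # a tuple containing the list
print(type(coords))       # <class 'tuple'>
print(type(coords[0]))    # <class 'list'>

# Inside brackets, a trailing comma is harmless and changes nothing:
names = ["train", "val", "test",]
print(type(names))        # <class 'list'>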
Usage of __all__: because __init__.py does from .dataset import *, the wildcard import brings in exactly the names listed in __all__ below.

__all__ = ["BAYER_MODE", "XTRANS_MODE", "Dataset",
           "TRAIN_SUBSET", "VAL_SUBSET", "TEST_SUBSET"]
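A minimal, made-up sketch of the mechanism (the module layout and names below are illustrative, not the real files):

# pkg/dataset.py  (hypothetical)
__all__ = ["Dataset", "BAYER_MODE"]

BAYER_MODE = "bayer"
XTRANS_MODE = "xtrans"   # not listed in __all__

class Dataset:
    pass

# pkg/__init__.py  (hypothetical)
from .dataset import *   # brings in only Dataset and BAYER_MODE;
                         # XTRANS_MODE would need an explicit import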
# Tip: find ttools' installation path, open it in PyCharm, and use global search (Find in Files) to jump to the definition.
[how to use global search in PyCharm]
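Alternatively, the install path can be printed directly (assuming ttools is importable in the current environment):

import ttools
print(ttools.__file__)   # path to the installed ttools/__init__.py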
def get_logger(name):
    """Get a named logger.

    Args:
        name(string): name of the logger
    """
    return logging.getLogger(name)
# The special (dunder) attribute __name__ holds the current module's name.
LOG = ttools.get_logger(__name__)
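Quick check of __name__ in a throwaway file (demo.py here is made up, not part of this repo):

# demo.py
print(__name__)
# python demo.py             -> prints "__main__"
# import demo (elsewhere)    -> prints "demo"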
def set_logger(debug=False):
    """Set the default logging level and log format.

    Args:
        debug(bool): if True, enable debug logs.
    """
    log_level = logging.INFO
    prefix = "[%(process)d] %(levelname)s %(name)s"
    suffix = " | %(message)s"
    if debug:
        log_level = logging.DEBUG
        prefix += " %(filename)s:%(lineno)s"
    if HAS_COLORED_LOGS:
        coloredlogs.install(
            level=log_level,
            format=prefix+suffix)
    else:
        logging.basicConfig(
            level=log_level,
            format=prefix+suffix)
# Configure the logger (enable debug-level logging)
ttools.set_logger(True)
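To see what that format string actually produces, here is a standalone sketch using only the standard library (no coloredlogs; the logger name is arbitrary):

import logging

fmt = "[%(process)d] %(levelname)s %(name)s | %(message)s"
logging.basicConfig(level=logging.DEBUG, format=fmt)

log = logging.getLogger("demosaicnet.dataset")
log.info("reading file list")
# Example output (the process id will differ):
# [12345] INFO demosaicnet.dataset | reading file list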
# Define a few module-level constants
BAYER_MODE = "bayer"
"""Applies a Bayer mosaic pattern."""

XTRANS_MODE = "xtrans"
"""Applies an X-Trans mosaic pattern."""

TRAIN_SUBSET = "train"
"""Loads the 'train' subset of the data."""

VAL_SUBSET = "val"
"""Loads the 'val' subset of the data."""

TEST_SUBSET = "test"
"""Loads the 'test' subset of the data."""


class Dataset(TorchDataset):
    """Dataset of challenging image patches for demosaicking.

    Args:
        download(bool): if True, automatically download the dataset.
        mode(:class:`BAYER_MODE` or :class:`XTRANS_MODE`): mosaic pattern to apply to the data.
        subset(:class:`TRAIN_SUBSET`, :class:`VAL_SUBSET` or :class:`TEST_SUBSET`): subset of the data to load.
    """
    def __init__(self, root, download=False,
                 mode=BAYER_MODE, subset="train"):
        super(Dataset, self).__init__()

        self.root = os.path.abspath(root)

        if subset not in [TRAIN_SUBSET, VAL_SUBSET, TEST_SUBSET]:
            raise ValueError("Dataset subset should be '%s', '%s' or '%s', got"
                             " %s" % (TRAIN_SUBSET, TEST_SUBSET, VAL_SUBSET,
                                      subset))

        if mode not in [BAYER_MODE, XTRANS_MODE]:
            raise ValueError("Dataset mode should be '%s' or '%s', got"
                             " %s" % (BAYER_MODE, XTRANS_MODE, mode))
        self.mode = mode

        listfile = os.path.join(self.root, subset, "filelist.txt")
        LOG.debug("Reading image list from %s", listfile)

        if not os.path.exists(listfile):
            if download:
                _download(self.root)
            else:
                LOG.error("Filelist %s not found", listfile)
                raise ValueError("Filelist %s not found" % listfile)
        else:
            LOG.debug("No need to download the data, filelist exists.")

        self.files = []
        with open(listfile, "r") as fid:
            for fname in fid.readlines():
                self.files.append(os.path.join(self.root, subset, fname.strip()))
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Fetches a mosaic / demosaicked pair of images.

        Returns:
            mosaic(np.array): with size [3, h, w] the mosaic data with separated color channels.
            img(np.array): with size [3, h, w] the groundtruth image.
        """
        fname = self.files[idx]
        img = np.array(imread(fname)).astype(np.float32) / (2**8-1)
        img = np.transpose(img, [2, 0, 1])

        if self.mode == BAYER_MODE:
            mosaic = bayer(img)
        else:
            mosaic = xtrans(img)

        return mosaic, img
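A rough usage sketch for the class (the "data" path is a placeholder; download=False assumes the ~80GB dataset was already downloaded and extracted there, and the package is assumed to be importable as demosaicnet):

from torch.utils.data import DataLoader
import demosaicnet

dataset = demosaicnet.Dataset("data", download=False,
                              mode=demosaicnet.BAYER_MODE,
                              subset=demosaicnet.TRAIN_SUBSET)
print(len(dataset))

mosaic, img = dataset[0]           # both are float32 numpy arrays
print(mosaic.shape, img.shape)     # expected: (3, h, w) for both

# Wrapping it in a DataLoader also works, assuming the patches share a fixed
# size; default_collate converts the numpy arrays to tensors automatically.
loader = DataLoader(dataset, batch_size=4, shuffle=True)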
CHECKSUMS = {
    'datasets.z01': 'da46277afe85d3a91c065e4751fb8175',
    'datasets.zip': '3434f60f5e9b263ef78e207b54e9debe',
}


def _download(dst):
    dst = os.path.abspath(dst)
    files = CHECKSUMS.keys()

    fullzip = os.path.join(dst, "datasets.zip")
    joinedzip = os.path.join(dst, "joined.zip")

    URL_ROOT = "https://data.csail.mit.edu/graphics/demosaicnet"

    if not os.path.exists(joinedzip):
        LOG.info("Downloading %d files to %s (This will take a while, and ~80GB)",
                 len(files), dst)
        os.makedirs(dst, exist_ok=True)

        for f in files:
            fname = os.path.join(dst, f)
            url = os.path.join(URL_ROOT, f)

            do_download = True

            if os.path.exists(fname):
                checksum = md5sum(fname)
                if checksum == CHECKSUMS[f]:  # File is there and correct
                    LOG.info('%s already downloaded, with correct checksum', f)
                    do_download = False
                else:
                    LOG.warning('%s checksums do not match, got %s, should be %s',
                                f, checksum, CHECKSUMS[f])
                    try:
                        os.remove(fname)
                    except OSError as e:
                        LOG.error("Could not delete broken part %s: %s", f, e)
                        raise ValueError

            if do_download:
                LOG.info('Downloading %s', f)
                wget.download(url, fname)
                checksum = md5sum(fname)
                if checksum == CHECKSUMS[f]:
                    LOG.info("%s MD5 correct", f)
                else:
                    LOG.error('%s checksums do not match, got %s, should be %s. Downloading failed',
                              f, checksum, CHECKSUMS[f])
LOG.info("Joining zip files")
cmd = " ".join(["zip", "-FF", fullzip, "--out", joinedzip])
subprocess.check_call(cmd, shell=True)
# Cleanup the parts
for f in files:
fname = os.path.join(dst, f)
try:
os.remove(fname)
except OSError as e:
LOG.warning("Could not delete file %s", f)
# Extract
wd = os.path.abspath(os.curdir)
os.chdir(dst)
LOG.info("Extracting files from %s", joinedzip)
cmd = " ".join(["unzip", joinedzip])
subprocess.check_call(cmd, shell=True)
try:
os.remove(joinedzip)
except OSError as e:
LOG.warning("Could not delete file %s", f)
LOG.info("Moving subfolders")
for k in ["train", "test", "val"]:
shutil.move(os.path.join(dst, "images", k), os.path.join(dst, k))
images = os.path.join(dst, "images")
LOG.info("removing '%s' folder", images)
shutil.rmtree(images)
def md5sum(filename, blocksize=65536):
    hash = hashlib.md5()
    with open(filename, "rb") as f:
        for block in iter(lambda: f.read(blocksize), b""):
            hash.update(block)
    return hash.hexdigest()
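For context, this is roughly how _download uses md5sum against the CHECKSUMS table above (the path below is a made-up placeholder):

part = "/tmp/demosaicnet/datasets.zip"      # hypothetical downloaded part
expected = CHECKSUMS["datasets.zip"]

if md5sum(part) == expected:
    print("checksum OK, no need to re-download")
else:
    print("checksum mismatch, the file would be deleted and downloaded again")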