def compress(self, x):
    """Encode ``x`` and return the coded strings with the latent's spatial shape.

    ``x`` is passed through ``self.g_a`` and the result is handed to
    ``self.entropy_bottleneck.compress``.

    Args:
        x: Input forwarded unchanged to ``self.g_a``.
            NOTE(review): presumably a batched image tensor -- confirm
            against the model definition.

    Returns:
        dict with two keys:
            "strings": a one-element list wrapping the bottleneck's output.
            "shape": the last two entries of ``y.size()``.
    """
    y = self.g_a(x)
    y_strings = self.entropy_bottleneck.compress(y)
    # Only the trailing two dimensions are stored; decompress() passes them
    # back to the bottleneck to rebuild the latent.
    return {"strings": [y_strings], "shape": y.size()[-2:]}


def decompress(self, strings, shape):
    """Rebuild an output from the ``strings`` / ``shape`` pair made by ``compress``.

    Args:
        strings: A list holding exactly one element, which is forwarded to
            ``self.entropy_bottleneck.decompress``.
        shape: Forwarded to ``self.entropy_bottleneck.decompress`` together
            with ``strings[0]``.

    Returns:
        dict with the single key "x_hat": the result of ``self.g_s`` after
        ``clamp_(0, 1)`` has been called on it.

    Raises:
        AssertionError: If ``strings`` is not a list of length 1.
    """
    assert isinstance(strings, list) and len(strings) == 1
    y_hat = self.entropy_bottleneck.decompress(strings[0], shape)
    # NOTE(review): clamp_ looks like the in-place torch clamp to [0, 1] --
    # confirm that g_s returns a tensor.
    x_hat = self.g_s(y_hat).clamp_(0, 1)
    return {"x_hat": x_hat}
class EntropyBottleneck(EntropyModel):
    """Entropy model whose compress/decompress supply indexes and medians.

    Both methods build an index tensor and a broadcast view of the medians,
    then delegate the actual coding to the parent ``EntropyModel``.
    """

    def compress(self, x):
        """Compress ``x`` by delegating to ``EntropyModel.compress``.

        Args:
            x: Object exposing ``size()``; dimension 0 is used as the
                expansion size for the medians and every dimension past the
                first two is counted as spatial.
                NOTE(review): presumably a (batch, channels, *spatial)
                tensor -- confirm against callers.

        Returns:
            Whatever ``super().compress(x, indexes, medians)`` returns.
        """
        indexes = self._build_indexes(x.size())
        medians = self._get_medians().detach()
        # Everything after the first two dimensions is treated as spatial.
        spatial_dims = len(x.size()) - 2
        medians = self._extend_ndims(medians, spatial_dims)
        # Expand only dimension 0 (to x.size(0)); -1 keeps the others as-is.
        medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Decompress ``strings`` by delegating to ``EntropyModel.decompress``.

        Args:
            strings: Sized collection of coded strings; its length becomes
                the leading dimension of the output size.
            size: Iterable of trailing dimensions, unpacked after the
                leading two entries of the output size.

        Returns:
            Whatever ``super().decompress(strings, indexes, medians.dtype,
            medians)`` returns.
        """
        # Output size is (number of strings, rows of the quantized CDF, *size).
        output_size = (len(strings), self._quantized_cdf.size(0), *size)
        # Keep the indexes on the same device as the quantized CDF table.
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        # Expand only dimension 0 (to len(strings)); -1 keeps the others as-is.
        medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
        return super().decompress(strings, indexes, medians.dtype, medians)
class _EntropyCoder:
底层代码中常用的是非对称数系编码(ANS)和区间编码,然后使用 index 进行编码/解码。将概率质量函数转换为量化的累积分布函数,并定义了一个占位符方法,鼓励在子类中提供具体实现。
class EntropyModel(nn.Module):
标签:编码,实现,self,indexes,medians,._,strings,size From: https://www.cnblogs.com/CLGYPYJ/p/17870805.html