暗通道去雾具体代码

发布于 2021-06-09  70 次阅读


湖南大学《数字图像处理》课程设计之暗通道去雾

基于亮通道和暗通道相结合的图像去雾算法

  基于暗通道先验原理的图像去雾算法具有较好的去雾效果, 但是对于天空区域或者白色物体地方容易出现色彩偏差和过增强现象。针对这一问题, 提出基于亮通道和暗通道结合的先验图像去雾算法, 采用加权引导滤波估计透射率的算法平滑传输图, 显示较多的细节, 解决色彩偏差和过增强问题。本文提出的基于亮通道和暗通道结合的先验图像去雾算法的流程如图2所示, 具体步骤如下:

  • 1) 根据亮通道和暗通道先验原理, 分别提取亮通道图像和暗通道图像;

  • 2) 将亮通道图像像素强度接近无雾图像大气光, 与暗通道图像求取符合条件像素点平均值的方法相结合, 得到大气光;

  • 3) 借助于大气光散射模型和大气光, 估测透射率;

  • 4) 采用加权引导滤波算法优化透射率, 平滑并细化图像;

  • 5) 通过优化估测透射率、大气光和大气光散射模型, 对输入图像进行去雾处理, 获取最终的无雾图像。

直方图均衡化与暗通道处理对比

  直方图均衡算法对于整幅图像的像素使用相同的直方图变换,对于那些像素值分布比较均衡的图像来说,算法的效果很好。然后,如果图像中包括明显比图像其它区域暗或者亮的部分,在这些部分的对比度将得不到有效的增强。而且部分细节信息失真严重。

  暗通道处理的显著特点是,有雾图像或去雾不完全的图像,在暗通道的图像仍然偏白,而自然图像与已去雾区域的暗通道是非常黑、非常暗的。这也是该通道被称为暗通道的原因。但暗通道图存在一定程度的像素化,这导致了边缘位置的不连续性,如北京的边缘位置残留的雾区域

  暗通道具有一定的参数依赖(受窗口大小等参数影响)(也被称之为先验性)以及在天空、柏油马路等区域(这两种情况下虽然没有雾,但局部区域RGB最小值仍然会很大,导致计算所得的暗通道值较大,被错误地判别为有雾区域进行去雾。

具体代码

MAIN.PY

import getopt
import os
import sys
from utility import darkChannel
from utility import equalizeHist
from utility import statistics

if __name__ == '__main__':
    print("""
🅲🅾🅳🅴 🅱🆈 🅴🅳🅼🆄🅽🅳 🆉🅷🅰🅾
🅶🅸🆃🅷🆄🅱:🅶🅸🆃🅷🆄🅱/🅴🅳🅼🆄🅽🅳-🆉🅷🅰🅾
    """)
    argv = sys.argv[1:]
    modelFlag = None
    path = None
    try:
        opts, args = getopt.getopt(argv, "i:m:h",["input=","model=","help"])  # 短选项模式
    except:
        print("Panic: Please Using -h or --help")
        sys.exit(3)

    for opt, arg in opts:
        if opt in ['-i','--input']:
            path = arg
        elif opt in ['-m','--model']:
            modelFlag = arg
        elif opt in ['-h','--help']:
            print("""
            -i,--input: 输入文件的相对路径
            -m,--model: equalizeHist 为全局直方图模式
                        darkChannel  为全局暗通道去雾模式
            """)
            sys.exit(4)
    if path is None:
        print("Please retry to input correct filePath!")
        sys.exit(3)
    if modelFlag == "equalizeHist":
        M = equalizeHist.EqualizeHist(path)
        M.runEqualizeHist()
    elif modelFlag == "darkChannel":
        M = darkChannel.Model(path)
        M.getRecoverScene()
    else:
        print("Please retry to input correct args!")

darkChannel.py

import time

import cv2
import numpy as np
import math
from utility import log
from utility import readImg


class Node():
    def __init__(self, x, y, value):
        self.x = x
        self.y = y
        self.value = value

    def printInfo(self):
        print('{}{}{}'.format(self.x, self.y, self.value))


class Model(readImg.Reader):
    def __init__(self, imgPath):
        super().__init__(imgPath)
        self.reading()

    def getMinChannel(self,img):

        # 输入检查
        if len(img.shape) == 3 and img.shape[2] == 3:
            pass
        else:
            print("bad image shape, input must be color image")
            return None

        imgGray = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
        localMin = 255

        for i in range(0, img.shape[0]):
            for j in range(0, img.shape[1]):
                localMin = 255
                for k in range(0, 3):
                    if img.item((i, j, k)) < localMin:
                        localMin = img.item((i, j, k))
                imgGray[i, j] = localMin

        return imgGray

    # 获取暗通道
    def getDarkChannel(self,img, blockSize=3):

        # 输入检查
        if len(img.shape) == 2:
            pass
        else:
            print("bad image shape, input image must be two demensions")
            return None

        # blockSize检查
        if blockSize % 2 == 0 or blockSize < 3:
            print('blockSize is not odd or too small')
            return None

        # 计算addSize
        addSize = (blockSize - 1) // 2

        newHeight = img.shape[0] + blockSize - 1
        newWidth = img.shape[1] + blockSize - 1

        # 中间结果
        imgMiddle = np.zeros((newHeight, newWidth))
        imgMiddle[:, :] = 255

        imgMiddle[addSize:newHeight - addSize, addSize:newWidth - addSize] = img

        imgDark = np.zeros((img.shape[0], img.shape[1]), np.uint8)
        localMin = 255

        for i in range(addSize, newHeight - addSize):
            for j in range(addSize, newWidth - addSize):
                localMin = 255
                for k in range(i - addSize, i + addSize + 1):
                    for l in range(j - addSize, j + addSize + 1):
                        if imgMiddle.item((k, l)) < localMin:
                            localMin = imgMiddle.item((k, l))
                imgDark[i - addSize, j - addSize] = localMin

        return imgDark

    # 获取全局大气光强度
    def getAtomsphericLight(self,darkChannel,img, meanMode=False, percent=0.001):

        size = darkChannel.shape[0] * darkChannel.shape[1]
        height = darkChannel.shape[0]
        width = darkChannel.shape[1]

        nodes = []

        # 用一个链表结构(list)存储数据
        for i in range(0, height):
            for j in range(0, width):
                oneNode = Node(i, j, darkChannel[i, j])
                nodes.append(oneNode)

        # 排序
        nodes = sorted(nodes, key=lambda node: node.value, reverse=True)

        atomsphericLight = 0

        # 原图像像素过少时,只考虑第一个像素点
        if int(percent * size) == 0:
            for i in range(0, 3):
                if img[nodes[0].x, nodes[0].y, i] > atomsphericLight:
                    atomsphericLight = img[nodes[0].x, nodes[0].y, i]

            return atomsphericLight

        # 开启均值模式
        if meanMode:
            sum = 0
            for i in range(0, int(percent * size)):
                for j in range(0, 3):
                    sum = sum + img[nodes[i].x, nodes[i].y, j]

            atomsphericLight = int(sum / (int(percent * size) * 3))
            return atomsphericLight

        # 获取暗通道前0.1%(percent)的位置的像素点在原图像中的最高亮度值
        for i in range(0, int(percent * size)):
            for j in range(0, 3):
                if img[nodes[i].x, nodes[i].y, j] > atomsphericLight:
                    atomsphericLight = img[nodes[i].x, nodes[i].y, j]

        return atomsphericLight
    @log.logged()
    def getRecoverScene(self, omega=0.97, t0=0.1, blockSize=15, meanMode=False, percent=0.01):

        imgGray = self.getMinChannel(self.img)
        imgDark = self.getDarkChannel(imgGray, blockSize=blockSize)
        atomsphericLight = self.getAtomsphericLight(imgDark,self.img,meanMode = meanMode,percent= percent)

        imgDark = np.float64(imgDark)
        transmission = 1 - omega * imgDark / atomsphericLight

        # 防止出现t小于0的情况
        # 对t限制最小值为0.1
        for i in range(0, transmission.shape[0]):
            for j in range(0, transmission.shape[1]):
                if transmission[i, j] < 0.1:
                    transmission[i, j] = 0.1

        sceneRadiance = np.zeros(self.img.shape)

        for i in range(0, 3):
            img = np.float64(self.img)
            sceneRadiance[:, :, i] = (img[:, :, i] - atomsphericLight) / transmission + atomsphericLight

            # 限制透射率 在0~255
            for j in range(0, sceneRadiance.shape[0]):
                for k in range(0, sceneRadiance.shape[1]):
                    if sceneRadiance[j, k, i] > 255:
                        sceneRadiance[j, k, i] = 255
                    if sceneRadiance[j, k, i] < 0:
                        sceneRadiance[j, k, i] = 0

        sceneRadiance = np.uint8(sceneRadiance)
        fileName = self.path.split('/')[-1].split('.')[0] + "_DC.png"
        cv2.imwrite(fileName, sceneRadiance)
        return sceneRadiance
    def __del__(self):
        localtime = time.asctime(time.localtime(time.time()))
        print(localtime, ": darkChannel Finish")
        return 0

equalizeHist.py

import time

import cv2
from utility import log
from utility import readImg

class EqualizeHist(readImg.Reader):
    @log.logged()
    def runEqualizeHist(self):
        self.reading()
        # 彩色图像均衡化,需要分解通道 对每一个通道均衡化
        (b, g, r) = cv2.split(self.img)
        bH = cv2.equalizeHist(b)
        gH = cv2.equalizeHist(g)
        rH = cv2.equalizeHist(r)
        # 合并每一个通道
        result = cv2.merge((bH, gH, rH))
        fileName = self.path.split('/')[-1].split('.')[0] + "_EqH.png"
        cv2.imwrite(fileName,result)
    def __del__(self):
        localtime = time.asctime(time.localtime(time.time()))
        print(localtime, ": equalizeHist Finish")
        return 0
if __name__ == '__main__':
    eq = EqualizeHist('./233.png')
    eq.runEqualizeHist()

log.py

from functools import wraps, partial
import logging
import time


def logged(func=None, *, level=logging.DEBUG, name=None, message=None):
    if func is None:
        return partial(logged, level=level, name=name, message=message)

    logname = name if name else func.__module__
    log = logging.getLogger(logname)
    logmsg = message if message else func.__name__

    @wraps(func)
    def wrapper(*args, **kwargs):
        log.log(level, logmsg)
        return func(*args, **kwargs)

    localtime = time.asctime(time.localtime(time.time()))
    print(localtime,": 正在加载{} {}".format(logmsg,logname))
    return wrapper

readImg.py

import cv2
import numpy as np

class Reader():

    def __init__(self,imgPath):
        self.path = imgPath

    def reading(self):
        self.img = cv2.imread(self.path)
        return 0


statistics.py

import json
import matplotlib.pyplot as plt
import seaborn as sns

class Statistic():
    def __init__(self):
        self.note = ""

    def statistic_picture(self):
        sns.set()
        plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']
        data = {'countTagId': {},
                'tagName': {}}
        with open('./imageInfo.pic','r') as fb:
            js = json.load(fb)
        for i in js['BOOKID']:
            tags = js['BOOKID'][i]['expand'].get('sysTags')
            if tags is not  None:
                for j in tags:
                    if data['countTagId'].get(j['tagName']) == None:
                        data['countTagId'][j['tagName']] = 0
                    data['countTagId'][j['tagName']] += 1
                    data['tagName'][j['sysTagId']] =j['tagName']
        print(data)
        a = sorted(data['countTagId'].items(), key=lambda x: x[1], reverse=True)
        print(a[:12])

        #
        pie, ax = plt.subplots(figsize=[10,6])
        labels = [u'{}'.format(i[0]) for i in a[:12]]
        x = [i[1] for i in a[:12]]
        plt.pie(x=x, autopct="%.1f%%", labels=labels, pctdistance=0.5)
        plt.title("Top12", fontsize=14)
        pie.savefig("PopularTags.png")



你知道雪为什么是白色的吗?因为她忘记了原来的颜色