# [garbled extraction residue, commented out] out_channels; pad=1, bias=true) norm1 = norm(out_channels) dilations = length(dilations) == 0 ? [1 ...
using Flux: output_size, train!
using Base: _before_colon, concatenate_setindex!, copymutable
using Flux, BSON
using Revise
using BenchmarkTools
using InteractiveUtils
using TimerOutputs
# Global timer registry: the `@timeit to "..."` calls in this file accumulate
# into it, and the driver script at the bottom prints it with `show(to)`.
const to = TimerOutput()

# Nested length of `x`, following the first element at each nesting level.
# A flat collection yields the one-element vector `[length(x)]`; a collection
# of arrays yields a tuple `(length(x), inner...)`.
function xshape(x)
    if !(eltype(x) <: AbstractArray)
        return [length(x)]
    end
    inner = xshape(x[1])
    return (length(x), inner...)
end
# Recursively flatten a nested collection of arrays into a single flat vector.
function tovec(x)
    # Base case: elements are not arrays, so just splat them into a vector.
    eltype(x) <: AbstractArray || return [x...]
    # Remove one nesting level, then recurse on the result.
    one_level = [item for sub in x for item in sub]
    return tovec(one_level)
end

# Encoder stage: an input convolution followed by a stack of (optionally
# residual) conv + norm blocks and an optional max-pool.
struct DownConv{C1<:Conv,C2<:Conv, BI<:Union{BatchNorm,InstanceNorm},M<:MaxPool,F<:Function}
    # First convolution, applied directly to the input.
    conv1::C1
    # One convolution per block, applied in sequence after `conv1`.
    conv2::Vector{C2}
    # Activation, broadcast element-wise in the forward pass.
    act::F
    # Normalisation applied to the output of `conv1`.
    norm1::BI
    # Normalisation layers indexed in parallel with `conv2`.
    bn::Vector{BI}
    # When true, each block's output is added to its input.
    residual::Bool
    # When true, `pool` is applied to the final activation.
    pooling:: Bool
    # Pooling layer used when `pooling` is true.
    pool::M
end
# Build a DownConv stage.
#
# Arguments
# - in_channels / out_channels: channel counts of the first convolution.
# - blocks: number of conv + norm blocks stacked after the first convolution.
# - pooling: whether the forward pass applies the 2x2 max-pool.
# - norm: constructor called as `norm(out_channels)` for every norm layer.
# - act: activation function stored in the layer.
# - residual: whether blocks add their input to their output.
# - dilations: per-block dilation (also used as padding); empty means all 1.
function DownConv(in_channels, out_channels, blocks; pooling=true, norm=Flux.BatchNorm, act=Flux.relu, residual=true, dilations=[])
    conv1 = Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
    norm1 = norm(out_channels)
    dilations = length(dilations) == 0 ? [1 for _ in 1:blocks] : dilations
    conv2 = [Conv((3,3), out_channels=>out_channels; dilation=dilations[i], pad=dilations[i])
        for i in 1:blocks]

    # One independent norm layer per block. `fill(norm(out_channels), blocks)`
    # would store the *same* layer object in every slot, so all blocks would
    # share parameters and running statistics.
    bn = [norm(out_channels) for _ in 1:blocks]

    pool = Flux.MaxPool((2, 2), stride=2)
    DownConv{typeof(conv1), eltype(conv2), typeof(norm1), typeof(pool), typeof(act)}(conv1, conv2, act, norm1, bn, residual, pooling, pool)
end

# Forward pass of a DownConv stage.
# Returns `(out, before_pool)`: `before_pool` is the activation after the last
# block, and `out` is its pooled version when `m.pooling` is true (otherwise
# the same array).
function (m::DownConv)(x::AbstractArray{Float32, 4})
    x1 = m.act.(m.norm1(m.conv1(x)))
    for (idx, conv) in enumerate(m.conv2)
        x2 = m.bn[idx](conv(x1))
        x1 = m.residual ? x1 + x2 : x2
        x1 = m.act.(x1)
    end
    # No copy is needed here: the pooling layer returns a new array and does
    # not modify `x1`, so the pre-pool activation can be returned as-is.
    before_pool = x1
    pooled = m.pooling ? m.pool(x1) : x1

    pooled, before_pool
end

Flux.@functor DownConv

# Decoder stage: upsample + convolution, merge with an encoder feature map,
# then a stack of (optionally residual) conv + norm blocks.
struct UpConv{C<:Conv,BI<:Union{BatchNorm,InstanceNorm},F<:Function}
    # Convolution applied to the bilinearly upsampled input.
    up_conv::C
    # Convolution applied to the merged (concatenated or summed) features.
    conv1::C
    # One convolution per block.
    conv2::Vector{C}
    # Normalisation layers indexed in parallel with `conv2`.
    bn::Vector{BI}
    # Normalisation after `up_conv`.
    norm0::BI
    # Normalisation after `conv1`.
    norm1::BI
    # Activation, broadcast element-wise in the forward pass.
    act::F
    # When true the skip features are concatenated along dim 3, otherwise added.
    concat::Bool
    # When true (and `concat` is true) the `mask` keyword is concatenated too.
    use_mask::Bool
    # When true, each block's output is added to its input.
    residual::Bool
end
# Build an UpConv stage.
# `dilations` gives the per-block dilation (also used as padding); an empty
# collection means dilation 1 for every block. With `concat=true` the merge
# convolution takes `2*out_channels` inputs (+1 when `use_mask` is set).
# The `use_att` and `out_fuse` keywords are accepted but not used here.
function UpConv(in_channels, out_channels, blocks; residual=true, norm=Flux.BatchNorm, act=Flux.relu, concat=true, use_att=false, use_mask=false, dilations=[], out_fuse=false)
    rates = length(dilations) == 0 ? [1 for _ in 1:blocks] : dilations

    up_conv = Flux.Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
    norm0 = norm(out_channels)

    # Input width of the merge convolution depends on how the skip is fused.
    merged_channels = if concat
        2*out_channels + (use_mask ? 1 : 0)
    else
        out_channels
    end
    conv1 = Conv((3,3), merged_channels=>out_channels; pad=1, bias=true)
    norm1 = norm(out_channels)

    conv2 = [Conv((3,3), out_channels=>out_channels; dilation=rates[i], pad=rates[i], bias=true) for i in 1:blocks]
    bn = [norm(out_channels) for _ in 1:blocks]

    return UpConv{typeof(up_conv), typeof(norm0), typeof(act)}(up_conv, conv1, conv2, bn, norm0, norm1, act, concat, use_mask, residual)
end
# Empty tag type carrying a value in its type parameter; `OutFuse(true)` and
# `OutFuse(false)` select between the two UpConv forward methods by dispatch.
struct OutFuse{v} end

# Lift a runtime value into the type parameter.
function OutFuse(x)
    return OutFuse{x}()
end
# UpConv forward pass, variant that also returns the fused feature map.
# Returns `(out, xfuse)` where `xfuse` is the activation right after the merge
# convolution and `out` is the result after all blocks. When `se` is given it
# is applied to the normalised output of the last block only.
function (m::UpConv)(::OutFuse{true}, from_up::AbstractArray{Float32, 4}, from_down::AbstractArray{Float32, 4}; mask=nothing, se=nothing)
    upsampled = Flux.upsample_bilinear(from_up, (2.0f0,2.0f0))
    from_up = m.act.(m.norm0(m.up_conv(upsampled)))

    # Merge with the skip connection: concatenate, add, or pass through when
    # `from_down` is an empty placeholder.
    if m.concat
        if m.use_mask
            merged = cat(from_up, from_down, mask; dims=Val(3))
        else
            merged = cat(from_up, from_down; dims=Val(3))
        end
    elseif prod(size(from_down)) != 0
        merged = from_up + from_down
    else
        merged = from_up
    end

    xfuse = m.act.(m.norm1(m.conv1(merged)))
    x1 = xfuse
    last_idx = length(m.conv2)
    for (idx, conv) in enumerate(m.conv2)
        x2 = m.bn[idx](conv(x1))
        if se !== nothing && idx == last_idx
            x2 = se(x2)
        end
        x1 = m.act.(m.residual ? x1 + x2 : x2)
    end
    return x1, xfuse
end

# UpConv forward pass, variant returning only the final feature map.
# When `se` is given it is applied to the normalised output of the last block.
function (m::UpConv)(::OutFuse{false}, from_up::AbstractArray{Float32, 4}, from_down::AbstractArray{Float32, 4}; mask=nothing, se=nothing)
    upsampled = Flux.upsample_bilinear(from_up, (2.0f0,2.0f0))
    from_up = m.act.(m.norm0(m.up_conv(upsampled)))

    # Merge with the skip connection: concatenate, add, or pass through when
    # `from_down` is an empty placeholder.
    if m.concat
        if m.use_mask
            merged = cat(from_up, from_down, mask; dims=Val(3))
        else
            merged = cat(from_up, from_down; dims=Val(3))
        end
    elseif prod(size(from_down)) != 0
        merged = from_up + from_down
    else
        merged = from_up
    end

    x1 = m.act.(m.norm1(m.conv1(merged)))

    last_idx = length(m.conv2)
    for (idx, conv) in enumerate(m.conv2)
        x2 = m.bn[idx](conv(x1))
        if se !== nothing && idx == last_idx
            x2 = se(x2)
        end
        x1 = m.act.(m.residual ? x1 + x2 : x2)
    end
    return x1
end

Flux.@functor UpConv

# Cross-level feature fusion block operating on three feature maps (x1, x2, x3).
struct CFFBlock{C1<:Conv, C2<:Conv, D1<:DownConv, D2<:DownConv, D3<:DownConv, CH1<:Chain, CH2<:Chain}
    # Applied to x3 after resizing it to x2's spatial size.
    up32::C1
    # Applied to x3 after resizing it to x1's spatial size.
    up31::C2
    # Processes x1 plus the resized x3 contribution.
    down1::D1
    # Processes the output of `down1` plus `conv22(x2)` and the `up32` branch.
    down2::D2
    # Processes the output of `down2` plus `conv33(x3)`.
    down3::D3
    # Projection applied to x2 before it is added in.
    conv22::CH1
    # Projection applied to x3 before it is added in.
    conv33::CH2
end

# Build a CFFBlock for base width `ngf`.
# The `down` and `up` keywords are accepted but the stages are always built
# with `DownConv`.
function CFFBlock(; down=DownConv, up=UpConv, ngf::Int=32)
    # Leaky-ReLU variants used by the fusion convs (0.2) and DownConv stages (0.01).
    fuse_act = x -> Flux.leakyrelu(x, eltype(x)(0.2))
    down_act = x -> Flux.leakyrelu(x, eltype(x)(0.01))

    up32 = Conv((3,3), ngf*4 => ngf*1, fuse_act; stride=1, pad=1)
    up31 = Conv((3,3), ngf*4 => ngf*1, fuse_act; stride=1, pad=1)

    down1 = DownConv(ngf, ngf, 3; pooling=true, norm=Flux.InstanceNorm, act=down_act, dilations=[])
    down2 = DownConv(ngf, ngf*2, 3; pooling=true, norm=Flux.InstanceNorm, act=down_act, dilations=[])
    down3 = DownConv(ngf*2, ngf*4, 3; pooling=false, norm=Flux.InstanceNorm, act=down_act, dilations=[1,2,5])

    conv22 = Chain(
        Conv((3,3), ngf*2 => ngf, fuse_act; stride=1, pad=1),
        Conv((3,3), ngf => ngf, fuse_act; stride=1, pad=1)
    )
    conv33 = Chain(
        Conv((3,3), ngf*4 => ngf*2, fuse_act; stride=1, pad=1),
        Conv((3,3), ngf*2 => ngf*2, fuse_act; stride=1, pad=1)
    )

    return CFFBlock{typeof(up32), typeof(up31), typeof(down1), typeof(down2), typeof(down3), typeof(conv22), typeof(conv33)}(
        up32, up31, down1, down2, down3, conv22, conv33)
end

# CFFBlock forward pass.
# Takes three feature maps and returns the three pre-pool activations
# `(d1, d2, d3)` of the internal DownConv stages.
function (m::CFFBlock)(x1::AbstractArray{Float32, 4}, x2::AbstractArray{Float32, 4}, x3::AbstractArray{Float32, 4})
    # `size` is the output extent of the first two (spatial) dims, in that
    # order. Passing them swapped only works for square inputs and makes the
    # additions below fail with a dimension mismatch otherwise.
    x32 = Flux.upsample_bilinear(x3; size=(size(x2, 1), size(x2, 2)))
    x32 = m.up32(x32)
    x31 = Flux.upsample_bilinear(x3; size=(size(x1, 1), size(x1, 2)))
    x31 = m.up31(x31)
    # cross-connection
    x, d1 = m.down1(x1 + x31)
    x, d2 = m.down2(x + m.conv22(x2) + x32)

    d3, _ = m.down3(x + m.conv33(x3))
    d1, d2, d3
end

Flux.@functor CFFBlock

# Channel-attention block: pools each channel to one value, runs a 1-D
# convolution across channels and rescales the input with the result.
struct ECABlocks{C<:Conv,AMP<:Flux.AdaptiveMeanPool}
    # 1-D convolution (with sigmoid activation) applied along the channel axis.
    conv::C
    # Adaptive mean pool reducing the spatial dims to 1x1.
    avg_pool::AMP
end

# Build an ECABlocks layer with a 1-D kernel of width `k_size`.
# `channel` is accepted but not used by the construction.
function ECABlocks(channel, k_size=3)
    half_width = (k_size - 1) ÷ 2
    gate_conv = Conv((k_size,), 1=>1, sigmoid; pad=half_width, bias=false)
    pooler = AdaptiveMeanPool((1,1))
    return ECABlocks{typeof(gate_conv), typeof(pooler)}(gate_conv, pooler)
end
# ECABlocks forward pass: rescale every channel of `x` by a gate computed from
# the channel-wise spatial means.
function (m::ECABlocks)(x::AbstractArray{Float32, 4})
    pooled = m.avg_pool(x)
    # Drop the leading singleton dim, then move channels to the first axis so
    # the 1-D convolution slides across channels.
    squeezed = reshape(pooled, size(pooled)[2:end])
    channel_seq = permutedims(squeezed, (2, 1, 3))
    # The sigmoid is already fused into the convolution's activation.
    gated = m.conv(channel_seq)
    # Restore the original 4-D layout with singleton spatial dims.
    weights = Flux.unsqueeze(permutedims(gated, (2, 1, 3)), dims=1)
    return x .* weights
end
Flux.@functor ECABlocks

# Mask-guided decoder block for the background branch.
struct MBEBlock{C<:Conv,BI<:Union{BatchNorm,InstanceNorm}, CH<:Chain,F<:Function}
    # Convolution applied to the bilinearly upsampled input.
    up_conv::C
    # Normalisation layers, one per block.
    bn::Vector{BI}
    # Normalisation after `up_conv`.
    norm0::BI
    # Normalisation after `conv1`.
    norm1::BI
    # Convolution applied to the merged features.
    conv1::C
    # Per-block chains that turn (half of the features + mask) into a new mask.
    conv2::Vector{CH}
    # Per-block convolutions applied to the mask-modulated half of the features.
    conv3::Vector{C}
    # Activation function.
    act::F
    # When true the skip features are concatenated along dim 3, otherwise added.
    concat::Bool
    # When true, each block's output is added to its input.
    residual::Bool
end

# Build an MBEBlock.
# Each of the `blocks` stages gets a two-layer 5x5 mask chain (input: half of
# the channels plus one mask channel) and a 3x3 convolution mapping half of the
# channels back to `out_channels`. `is_final` is accepted but not used.
function MBEBlock(in_channels=512, out_channels=3; norm=Flux.BatchNorm, act=Flux.relu, blocks=1, residual=true, concat=true, is_final=true)
    half = out_channels ÷ 2
    quarter = out_channels ÷ 4
    merged_channels = concat ? 2*out_channels : out_channels

    up_conv = Flux.Conv((3,3), in_channels=>out_channels; pad=1, bias=true)
    conv1 = Conv((3,3), merged_channels=>out_channels; pad=1, bias=true)

    conv2 = Vector{Chain}()
    conv3 = Vector{Conv}()
    for _ in 1:blocks
        mask_chain = Chain(
            Conv((5,5), (half + 1)=>quarter, Flux.relu; stride=1, pad=2, bias=true),
            Conv((5,5), quarter=>1, Flux.sigmoid_fast; stride=1, pad=2, bias=true)
        )
        push!(conv2, mask_chain)
        push!(conv3, Conv((3,3), half=>out_channels; pad=1, bias=true))
    end
    bn = [norm(out_channels) for _ in 1:blocks]
    norm0 = norm(out_channels)
    norm1 = norm(out_channels)

    return MBEBlock{typeof(up_conv), eltype(bn), eltype(conv2), typeof(act)}(up_conv, bn, norm0, norm1, conv1, conv2, conv3, act, concat, residual)
end

# MBEBlock forward pass.
#
# Upsamples `from_up`, merges it with `from_down` (concatenation or addition),
# then for every block refines `mask` from the first half of the channels and
# uses it to modulate the second half. Returns the final feature map.
# Timing sections "3.3.1" .. "3.3.6" are recorded in the global timer `to`.
function (m::MBEBlock)(from_up::AbstractArray{Float32, 4}, from_down::AbstractArray{Float32, 4}; mask=nothing)
    # Each stage is evaluated once and timed in place (`@timeit` returns the
    # value of the timed expression). The activation is broadcast, matching
    # every other use of `act` in this file.
    from_up = @timeit to "3.3.1" m.act.(m.norm0(m.up_conv(Flux.upsample_bilinear(from_up, (2.0f0, 2.0f0)))))
    # `from_down` is always an array (see the signature), so only two merge
    # modes are reachable.
    x1 = @timeit to "3.3.2" begin
        m.concat ? cat(from_up, from_down; dims=Val(3)) : from_up + from_down
    end
    x1 = @timeit to "3.3.3" m.act.(m.norm1(m.conv1(x1)))
    # residual structure
    C = size(x1, 3)
    half = C ÷ 2
    for (idx, (mask_chain, fuse_conv)) in enumerate(zip(m.conv2, m.conv3))
        lower = view(x1, :, :, 1:half, :)
        upper = view(x1, :, :, (half+1):C, :)
        # Refine the mask from the first half of the channels and the previous mask.
        mask = @timeit to "3.3.4" mask_chain(cat(lower, mask; dims=Val(3)))
        x2_actv = @timeit to "3.3.5" upper .* mask
        x2 = @timeit to "3.3.6" fuse_conv(upper + x2_actv)
        x2 = m.bn[idx](x2)
        x1 = m.residual ? x2 + x1 : x2
        x1 = m.act.(x1)
    end
    x1
end
Flux.@functor MBEBlock

# Mask-guided attention layer; see `compute_attention` for how the fields are used.
struct SelfAttentionSimple{C<:Conv, CH<:Chain, AA<:AbstractArray{Float32, 4}}
    # Number of key/value groups produced by `k_conv` / `v_conv`.
    k_center::Int32
    # 1x1 projection of the query.
    q_conv::C
    # 1x1 projection of the key into `k_center` channel groups.
    k_conv::C
    # 1x1 projection of the key input into value groups.
    v_conv::C
    # 1x1 convolution scoring a concatenated (query, key) pair.
    sim_func::C
    # Output head applied to `attn + query`.
    out_conv::CH
    # Lower clamp for the thresholded mask area.
    min_area::Float32
    # Threshold used to binarise the incoming mask.
    threshold::Float32
    # Array of ones of size (1, 1, k_center, 1); not read by the code in this file.
    k_weight::AA

end

# Build a SelfAttentionSimple layer for `in_channel` features and `k_center`
# key/value groups. `min_area` and `threshold` are fixed to 100 and 0.5.
function SelfAttentionSimple(in_channel, k_center)
    grouped = in_channel * k_center
    q_conv = Conv((1, 1), in_channel=>in_channel)
    k_conv = Conv((1, 1), in_channel=>grouped)
    v_conv = Conv((1, 1), in_channel=>grouped)
    sim_func = Conv((1,1), (2*in_channel) => 1; stride=1, pad=0, bias=true)

    reduced = in_channel ÷ 8
    out_conv = Chain(
        Conv((3,3), in_channel=>reduced, Flux.relu; stride=1, pad=1),
        Conv((3,3), reduced=>1; stride=1, pad=1)
    )
    k_weight = fill(1.0f0, (1, 1, k_center, 1))

    return SelfAttentionSimple{typeof(q_conv), typeof(out_conv), typeof(k_weight)}(
        Int32(k_center),
        q_conv, k_conv, v_conv, sim_func, out_conv,
        100.0f0, 0.5f0,
        k_weight
    )
end

# Compute the raw (pre-sigmoid) attention map for `query`.
#
# Steps:
# 1. Project `query` and `key`; split the projected key into groups of `c`
#    channels (`c` = channel count of `query`).
# 2. Binarise `mask` with `m.threshold` and pool every key group into a single
#    descriptor by mask-weighted averaging (for `k_center == 2` the second
#    group is averaged over the complement of the mask).
# 3. Score the query against every descriptor with `m.sim_func`.
# 4. Pool the value projection the same way and mix the values with the scores.
# 5. Return `m.out_conv(attn + query)`.
#
# `eps` is added to the background area to avoid division by zero.
function compute_attention(m::SelfAttentionSimple, query::T, key::T, mask::T, eps=1) where{T}
    h, w, c, b = size(query)
    query = m.q_conv(query)
    key_in = key
    key = m.k_conv(key_in)
    # One slice of `c` channels per key group.
    keys = Vector{T}([key[:,:,i:(i+c-1),:] for i in 1:c:size(key)[3]])
    # Hard 0/1 map of the pixels above the threshold.
    importance_map = eltype(mask).(mask .> m.threshold)
    s_area = sum(importance_map, dims=[1,2])
    # Avoid dividing by tiny areas.
    clamp!(s_area, m.min_area, 1.0f7)

    s_area = s_area[:, :, 1:1, :]

    if m.k_center != 2
        ks = Vector{T}()
        for k in keys
            # Reduction dims must be plain integers; `Val` dims are rejected by `sum`.
            push!(ks, sum(k .* importance_map, dims=(1, 2)) ./ s_area)
        end
        keys = ks
    else
        # Foreground descriptor from the mask, background from its complement.
        keys = Vector{T}([sum(keys[1] .* importance_map, dims=[1,2]) ./ s_area,
        sum(keys[2] .* (one(eltype(importance_map)) .- importance_map), dims=[1,2]) ./ (size(keys[2])[1] * size(keys[2])[2] .+ eps .- s_area)
        ])
    end

    f_query = query

    # Broadcast every pooled descriptor back to the query's spatial size.
    f_key = [Flux.upsample_nearest(reshape(k, (1, 1, c, b)), size=(size(f_query)[1], size(f_query)[2])) for k in keys]

    attention_scores = Vector{T}()
    for k in f_key
        combine_qk = Flux.tanh_fast.(cat(f_query, k, dims=Val(3)))
        sk = m.sim_func(combine_qk)
        push!(attention_scores, sk)
    end
    # Stack the per-descriptor score maps along dim 3 (equivalent to `cat(...; dims=3)`).
    s = reshape(mapreduce(Flux.flatten, vcat, attention_scores), size(attention_scores[1])[1], size(attention_scores[1])[2], sum([size(x)[3] for x in attention_scores]), size(attention_scores[1])[4])
    s = permutedims(s, (3,1,2,4))

    v = m.v_conv(key_in)
    if m.k_center == 2
        # The value projection has 2c channels: 1:c is the foreground group and
        # c+1:2c the background group. The mask weights every pixel *before*
        # the spatial sum, mirroring the key pooling above.
        v_fg = sum(view(v, :, :, 1:c, :) .* importance_map, dims=[1,2]) ./ s_area
        v_bg = sum(view(v, :, :, (c+1):size(v)[3], :) .* (1 .- importance_map); dims=[1,2]) ./ (size(v)[1] * size(v)[2] .- s_area .+ eps)
        v = cat(v_fg, v_bg; dims=Val(3))
    else
        v = sum(v .* importance_map, dims=[1,2]) ./ s_area
    end
    v = reshape(v, (c, m.k_center, b))
    # (c, k, b) * (k, h*w, b) -> (c, h*w, b), then back to (h, w, c, b).
    attn = permutedims(reshape(Flux.batched_mul(v, reshape(s, (m.k_center, h*w, b))), (c,h,w,b)), (2, 3, 1,4))
    m.out_conv(attn + query)
end

# SelfAttentionSimple forward pass.
# Returns `xout` unchanged together with the sigmoid of the attention map,
# reshaped to a single channel with the spatial/batch size of `xin`.
function (m::SelfAttentionSimple)(xin::AbstractArray{Float32, 4}, xout::AbstractArray{Float32, 4}, xmask::AbstractArray{Float32, 4})
    h, w, _, b_num = size(xin)
    raw_scores = compute_attention(m, xin, xout, xmask)
    gate = Flux.sigmoid.(reshape(raw_scores, (h, w, 1, b_num)))
    return xout, gate
end
Flux.@functor SelfAttentionSimple

# Mask-branch decoder block: an UpConv, a primary mask head and an attention
# layer that produces a second mask from the first.
struct SMRBlock{U<:UpConv, C<:Conv, S<:SelfAttentionSimple}
    # Upsampling / fusion stage.
    upconv::U
    # 1x1 convolution producing the primary mask.
    primary_mask::C
    # Attention layer producing the second ("self-calibrated") mask.
    self_calibrated::S
end

# Build an SMRBlock mapping `ins` to `outs` channels with `k_center`
# attention groups; the remaining keywords are forwarded to `UpConv`.
function SMRBlock(ins, outs, k_center; norm=Flux.BatchNorm, act=Flux.relu, blocks=1, residual=true, concat=true)
    upconv = UpConv(ins, outs, blocks; residual=residual, concat=concat, norm=norm, act=act)
    mask_head = Conv((1,1), outs=>1, Flux.sigmoid_fast; stride=1, pad=0, bias=true)
    attention = SelfAttentionSimple(outs, k_center)
    return SMRBlock{typeof(upconv), typeof(mask_head), typeof(attention)}(upconv, mask_head, attention)
end

# SMRBlock forward pass.
# Returns a Dict with
#   "feats"     => [feature map after the UpConv stage]
#   "attn_maps" => [primary mask, self-calibrated mask]
# When `encoder_outs` is `nothing`, an empty placeholder array is passed to the
# UpConv stage instead of a skip connection.
function (m::SMRBlock)(input::AbstractArray{Float32, 4}, encoder_outs=nothing)
    # Array constructors take dims as a tuple, not a Vector; `similar` with an
    # all-zero size yields an empty array of the same kind as `input`.
    skip = encoder_outs === nothing ? similar(input, ntuple(_ -> 0, ndims(input))) : encoder_outs
    mask_x, _ = m.upconv(OutFuse(true), input, skip)
    primary_mask = m.primary_mask(mask_x)
    mask_x, self_calibrated_mask = m.self_calibrated(mask_x, mask_x, primary_mask)
    return Dict(
        "feats"=>[mask_x],
        "attn_maps"=>[primary_mask, self_calibrated_mask]
    )
end
Flux.@functor SMRBlock

# Encoder made of a sequence of DownConv stages applied one after another.
struct CoarseEncoder{D<:DownConv}
    # Stages in application order.
    down_convs::Vector{D}
end
# Build a CoarseEncoder with `depth` pooled DownConv stages.
# Stage `i` outputs `start_filters * 2^(i-1)` channels. `blocks` may be an
# integer or an array, in which case its first element is used.
function CoarseEncoder(in_channels::Int=3, depth::Int=3; blocks=1, start_filters=32, residual=true, norm=Flux.BatchNorm, act=Flux.relu)
    down_convs = DownConv[]
    outs = nothing
    if isa(blocks, AbstractArray)
        # Julia arrays are 1-indexed; `blocks[0]` would throw a BoundsError.
        blocks = first(blocks)
    end

    for i in 1:depth
        ins = i==1 ? in_channels : outs
        outs = start_filters*(2^(i-1))
        pooling = true
        push!(down_convs, DownConv(ins, outs, blocks, pooling=pooling, residual=residual, norm=norm, act=act))
    end
    @show typeof(down_convs)
    CoarseEncoder{DownConv}(down_convs)
end

# CoarseEncoder forward pass.
# Returns the output of the last stage and the list of pre-pool activations
# of every stage, in application order.
function (m::CoarseEncoder)(x::AbstractArray{Float32, 4})
    skips = Vector{typeof(x)}()
    feat = x
    for stage in m.down_convs
        feat, pre_pool = stage(feat)
        push!(skips, pre_pool)
    end
    return feat, skips
end
Flux.@functor CoarseEncoder

# Bottleneck shared by the image and mask branches: a common encoder followed
# by a decoder that is run twice with different attention layers.
struct SharedBottleNeck{U<:UpConv, D<:DownConv, ECA1<:ECABlocks, ECA2<:ECABlocks}
    # Decoder stages, used by both branches.
    up_convs::Vector{U}
    # Encoder stages.
    down_convs::Vector{D}
    # Attention layers passed as `se` when decoding the image branch.
    up_im_atts::Vector{ECA1}
    # Attention layers passed as `se` when decoding the mask branch.
    up_mask_atts::Vector{ECA2}
end
# Build a SharedBottleNeck.
# Creates one DownConv per level in `depth - shared_depth : depth - 1`, doubling
# the channel count up to 512, and for every level except the last an UpConv
# plus one ECABlocks per branch. Prints its arguments and progress while
# constructing.
function SharedBottleNeck(in_channels=512, depth=5, shared_depth=2; start_filters=32, blocks=1, residual=true, concat=true, norm=Flux.BatchNorm, act=Flux.relu, dilations=[1,2,5])
    @show in_channels
    @show depth
    @show shared_depth
    @show start_filters
    @show blocks
    @show residual
    @show concat
    @show norm
    @show act
    @show dilations
    start_depth = depth - shared_depth
    max_filters = 512
    down_convs = DownConv[]
    up_convs = UpConv[]
    up_im_atts = ECABlocks[]
    up_mask_atts = ECABlocks[]
    outs = 0
    println("construct SharedBottleNeck")
    for i in start_depth:depth-1
        @show i, start_depth, depth
        ins = (i == start_depth) ? in_channels : outs
        outs = min(ins*2, max_filters)
        # Every level except the last one pools and gets a matching decoder stage.
        has_decoder = i < depth - 1
        # encoder convs
        push!(down_convs, DownConv(ins, outs, blocks, pooling=has_decoder, residual=residual, norm=norm, act=act, dilations=dilations))
        # decoder convs
        if has_decoder
            @show i,depth
            up_conv = UpConv(min(outs*2, max_filters), outs, blocks, residual=residual, concat=concat, norm=norm, act=Flux.relu, dilations=dilations)
            println(typeof(up_conv))
            push!(up_convs, up_conv)
            println(typeof(ECABlocks(outs)))
            push!(up_im_atts, ECABlocks(outs))
            push!(up_mask_atts, ECABlocks(outs))
        end
    end
    return SharedBottleNeck{eltype(up_convs), eltype(down_convs), eltype(up_im_atts), eltype(up_mask_atts)}(up_convs, down_convs, up_im_atts, up_mask_atts)
end

# SharedBottleNeck forward pass.
# Encodes `input` once, then decodes the result twice with the same UpConv
# stages: once with the image attention layers and once with the mask
# attention layers. Returns `(x_im, x_mask)`.
function (m::SharedBottleNeck)(input::AbstractArray{Float32, 4})
    # encoder convs; both branches read the same pre-pool activations, so a
    # single skip list is enough.
    encoder_outs = Vector{typeof(input)}()
    x = input
    for d_conv in m.down_convs
        x, before_pool = d_conv(x)
        push!(encoder_outs, before_pool)
    end
    bottleneck = x

    # Decoder convs, image branch. The skip list is always a Vector, so no
    # placeholder for a missing skip connection is needed.
    x_im = bottleneck
    for (i, (up_conv::eltype(m.up_convs), attn::eltype(m.up_im_atts))) in enumerate(zip(m.up_convs, m.up_im_atts))
        x_im = up_conv(OutFuse(false), x_im, encoder_outs[end-i], se=attn)
    end

    # Decoder convs, mask branch.
    x_mask = bottleneck
    for (i, (up_conv::eltype(m.up_convs), attn::eltype(m.up_mask_atts))) in enumerate(zip(m.up_convs, m.up_mask_atts))
        x_mask = up_conv(OutFuse(false), x_mask, encoder_outs[end-i], se=attn)
    end

    x_im, x_mask
end
Flux.@functor SharedBottleNeck

# Two-branch decoder: a background branch of MBEBlocks and a mask branch of
# SMRBlocks, run level by level.
struct CoarseDecoder{C<:Conv, MBE<:MBEBlock, SMR<:SMRBlock, ECA<:ECABlocks}
    # Background-branch blocks, one per level.
    up_convs_bg::Vector{MBE}
    # Mask-branch blocks, one per level.
    up_convs_mask::Vector{SMR}
    # Attention applied to the encoder skip before the mask branch (when `use_att`).
    atts_mask::Vector{ECA}
    # Attention applied to the encoder skip before the background branch (when `use_att`).
    atts_bg::Vector{ECA}
    # Final 1x1 convolution producing the background output.
    conv_final_bg::C
    # Whether the attention layers are applied to the skips.
    use_att::Bool
end

# Build a CoarseDecoder with `depth` levels; every level halves the channel
# count. The background blocks always use InstanceNorm, the mask blocks use
# `norm`. Attention layers are only created when `use_att` is true.
function CoarseDecoder(in_channels=512, out_channels=3, k_center=2; norm=Flux.BatchNorm, act=Flux.relu, depth=5, blocks=1, residual=true, concat=true, use_att=false)
    up_convs_bg = MBEBlock[]
    up_convs_mask = SMRBlock[]
    atts_bg = ECABlocks[]
    atts_mask = ECABlocks[]

    width = in_channels
    for _ in 1:depth
        ins = width
        width = ins ÷ 2
        # background reconstruction branch
        push!(up_convs_bg, MBEBlock(ins, width; blocks=blocks, residual=residual, concat=concat, norm=Flux.InstanceNorm, act=act))
        use_att && push!(atts_bg, ECABlocks(width))
        # mask prediction branch
        push!(up_convs_mask, SMRBlock(ins, width, k_center; norm=norm, act=act, blocks=blocks, residual=residual, concat=concat))
        use_att && push!(atts_mask, ECABlocks(width))
    end

    conv_final_bg = Conv((1,1), width=>out_channels, stride=1, pad=0, bias=true)
    return CoarseDecoder{typeof(conv_final_bg), eltype(up_convs_bg), eltype(up_convs_mask), eltype(atts_mask)}(
        up_convs_bg, up_convs_mask, atts_mask, atts_bg,
        conv_final_bg,
        use_att
    )
end

# CoarseDecoder forward pass.
#
# For every level the mask branch runs first; its self-calibrated mask then
# guides the background branch. Returns `(bg_outs, mask_outs, nothing)` where
# `bg_outs` holds the background features of every level plus the final
# convolution output, and `mask_outs` holds the primary and self-calibrated
# masks of every level plus a repeat of the last one. `fg` is not used.
# Timing sections "3.1" .. "3.5" are recorded in the global timer `to`.
function (m::CoarseDecoder)(bg::T, fg, mask::T, encoder_outs=nothing) where{T}
    bg_x = bg
    mask_x = mask
    mask_outs = Vector{T}()
    bg_outs = Vector{T}()
    for (i, (up_bg, up_mask)) in enumerate(zip(m.up_convs_bg, m.up_convs_mask))
        # Skips are consumed from the deepest encoder level outwards.
        before_pool = encoder_outs === nothing ? nothing : encoder_outs[end-(i-1)]
        if m.use_att
            # Every layer is evaluated once; `@timeit` returns the timed value.
            mask_before_pool = @timeit to "3.1" m.atts_mask[i](before_pool)
            bg_before_pool = @timeit to "3.2" m.atts_bg[i](before_pool)
            @show size(bg_before_pool)
        else
            # Without attention both branches use the raw skip, so these
            # variables are defined on every path.
            mask_before_pool = before_pool
            bg_before_pool = before_pool
        end
        smr_outs = @timeit to "3.3" up_mask(mask_x, mask_before_pool)
        mask_x = smr_outs["feats"][1]
        primary_map, self_calibrated_map = smr_outs["attn_maps"]
        @show size(self_calibrated_map)
        push!(mask_outs, primary_map)
        push!(mask_outs, self_calibrated_map)
        # Original author's note: there may be a problem here.
        bg_x = @timeit to "3.4" up_bg(bg_x, bg_before_pool; mask = self_calibrated_map)
        @show size(bg_x)
        push!(bg_outs, bg_x)
    end
    if m.conv_final_bg !== nothing
        bg_x = @timeit to "3.5" m.conv_final_bg(bg_x)
        push!(bg_outs, bg_x)
        # Keep both lists the same length by repeating the last mask.
        push!(mask_outs, mask_outs[end])
    end
    return bg_outs, mask_outs, nothing
end
Flux.@functor CoarseDecoder

# Refinement stage: a small encoder over the coarse result and mask, optional
# skips from decoder features, a stack of CFFBlocks and an output head.
struct Refinement{C<:Conv, CH1<:Chain, CH2<:Chain, CH3<:Chain, CH4<:Chain,D1<:DownConv,D2<:DownConv,D3<:DownConv,CFF<:CFFBlock}
    # Input stem applied to the concatenated coarse image and mask.
    conv_in::CH1
    # Projection of the first decoder feature map.
    dec_conv2::C
    # Projection of the second decoder feature map.
    dec_conv3::CH2
    # Projection of the third decoder feature map.
    dec_conv4::CH3
    # Three encoder stages applied in sequence.
    down1::D1
    down2::D2
    down3::D3
    # Fusion blocks applied in sequence to the three pre-pool activations.
    cff_blocks::Vector{CFF}
    # Output head applied to the concatenated, resized fusion outputs.
    out_conv::CH4
    # Controls which decoder skips are added (see the forward method).
    n_skips::Int64
end

# Build a Refinement stage with base width `ngf` and `n_cff` fusion blocks.
# `down` is the constructor used for the three encoder stages; `shared_depth`
# and `up` are accepted but not used.
function Refinement(;in_channels=3, out_channels=3, shared_depth=2, down=DownConv, up=UpConv, ngf=32, n_cff=3, n_skips=3)
    # Scalar leaky-ReLU variants used as convolution / DownConv activations.
    proj_act = x -> Flux.leakyrelu(x, eltype(x)(0.2))
    down_act = x -> Flux.leakyrelu(x, eltype(x)(0.01))

    conv_in = Chain(
        Conv((3,3), in_channels=>ngf; stride=1, pad=1, bias=true),
        Flux.InstanceNorm(ngf),
        x->Flux.leakyrelu.(x, eltype(x)(0.2))
    )

    # Projections of the decoder features that may be added as skips.
    dec_conv2 = Conv((1,1), ngf=>ngf; stride=1, pad=0, bias=true)
    dec_conv3 = Chain(
        Conv((1,1), (ngf*2)=>ngf, proj_act; stride=1, pad=0, bias=true),
        Conv((3,3), ngf=>ngf, proj_act; stride=1, pad=1, bias=true)
    )
    dec_conv4 = Chain(
        Conv((1,1), (ngf*4)=>(ngf*2), proj_act; stride=1, pad=0, bias=true),
        Conv((3,3), (ngf*2)=>(ngf*2), proj_act; stride=1, pad=1, bias=true)
    )

    down1 = down(ngf, ngf, 3, pooling=true, norm=Flux.InstanceNorm, act=down_act, dilations=[])
    down2 = down(ngf, (ngf*2), 3, pooling=true, norm=Flux.InstanceNorm, act=down_act, dilations=[])
    down3 = down((ngf*2), (ngf*4), 3, pooling=false, norm=Flux.InstanceNorm, act=down_act, dilations=[1,2,5])

    cff_blocks = [CFFBlock(;ngf=ngf) for _ in 1:n_cff]

    out_conv = Chain(
        Conv((3,3), (ngf+ngf*2+ngf*4)=>ngf; stride=1, pad=1, bias=true),
        Flux.InstanceNorm(ngf),
        x->Flux.leakyrelu.(x, eltype(x)(0.2)),
        Conv((1,1), ngf=>out_channels; stride=1, pad=0)
    )

    return Refinement{typeof(dec_conv2), typeof(conv_in), typeof(dec_conv3), typeof(dec_conv4), typeof(out_conv), typeof(down1), typeof(down2), typeof(down3), eltype(cff_blocks)}(
        conv_in, dec_conv2, dec_conv3, dec_conv4, down1, down2, down3, cff_blocks, out_conv, n_skips)
end

# Refinement forward pass.
#
# Concatenates `coarse_bg` and `mask`, runs the three encoder stages (adding
# projected decoder features according to `m.n_skips`), applies every CFFBlock
# to the three pre-pool activations, resizes them to the spatial size of
# `coarse_bg`, concatenates them along dim 3 and applies the output head.
# `input` and `encoder_outs` are not used.
function (m::Refinement)(input::AbstractArray{Float32, 4}, coarse_bg::AbstractArray{Float32, 4}, mask::AbstractArray{Float32, 4}, encoder_outs, decoder_outs::Vector{T} where T<:AbstractArray{Float32, 4})
    xin = cat(coarse_bg, mask, dims=Val(3))
    x = m.conv_in(xin)
    m.n_skips < 1 && (x += m.dec_conv2(decoder_outs[1]))
    x, d1 = m.down1(x)
    m.n_skips < 2 && (x += m.dec_conv3(decoder_outs[2]))
    x, d2 = m.down2(x)
    m.n_skips < 3 && (x += m.dec_conv4(decoder_outs[3]))
    x, d3 = m.down3(x)

    for block in m.cff_blocks
        d1, d2, d3 = block(d1, d2, d3)
    end

    # `size` is the output extent of the first two (spatial) dims, in that
    # order; passing them swapped only works for square inputs.
    out_size = (size(coarse_bg, 1), size(coarse_bg, 2))
    xs = [Flux.upsample_bilinear(x_hr; size=out_size) for x_hr in (d1, d2, d3)]

    xct = Base.cat_t(eltype(xs[1]), xs...; dims=3)
    return m.out_conv(xct)
end
Flux.@functor Refinement

# Full network: encoder, shared bottleneck, two-branch coarse decoder and
# refinement stage.
struct SLBR{CE<:CoarseEncoder, SB<:SharedBottleNeck, CD<:CoarseDecoder, RF<:Refinement}
    # Encoder applied to the input image.
    encoder::CE
    # Shared bottleneck producing the image and mask codes.
    shared_decoder::SB
    # Coarse decoder producing background features and masks.
    coarse_decoder::CD
    # Refinement stage applied to the coarse composite.
    refinement::RF
    # When true the input is added to the coarse reconstruction and clamped to [0, 1].
    long_skip::Bool
end

# Build the full SLBR network and print the configuration while doing so.
# `blocks[1]`, `blocks[5]` and `blocks[2]` give the block counts of the
# encoder, the shared bottleneck and the coarse decoder respectively.
# `out_channels_mask` is accepted but not used.
function SLBR(; in_channels=3, depth=5, shared_depth=2, blocks=[1 for i in 1:5], out_channels_image=3, out_channels_mask=1, start_filters=32, residual=true, concat=true, long_skip=false, k_refine=3, n_skips=3, k_center=3)
    println("construct SLBR")
    @show in_channels
    @show depth
    @show shared_depth
    @show blocks
    @show out_channels_image
    @show out_channels_mask
    @show start_filters
    @show residual
    @show concat
    @show long_skip
    @show k_refine
    @show n_skips
    @show k_center

    # Number of levels that are not shared between the two branches.
    unshared_depth = depth - shared_depth

    encoder = CoarseEncoder(in_channels, unshared_depth; blocks=blocks[1], start_filters=start_filters, residual=residual, norm=Flux.BatchNorm, act=Flux.relu)

    bottleneck_in = start_filters*2^(unshared_depth-1)
    shared_decoder = SharedBottleNeck(bottleneck_in, depth, shared_depth; blocks=blocks[5], residual=residual, concat=concat, norm=Flux.InstanceNorm)

    decoder_in = start_filters*2^unshared_depth
    coarse_decoder = CoarseDecoder(decoder_in, out_channels_image, k_center; depth=unshared_depth, blocks=blocks[2], residual=residual, concat=concat, norm=Flux.BatchNorm, use_att=true)

    refinement = Refinement(; in_channels=4, out_channels=3, shared_depth=1, n_cff=k_refine, n_skips=n_skips)

    return SLBR{typeof(encoder), typeof(shared_decoder), typeof(coarse_decoder), typeof(refinement)}(encoder, shared_decoder, coarse_decoder, refinement, long_skip)
end

# SLBR forward pass.
#
# Returns `([refine_bg, reconstructed_image], mask, [reconstruct_wm])`:
# - `refine_bg`: refined output, `tanh` of the refinement result plus the
#   input, clamped to [0, 1];
# - `reconstructed_image`: coarse reconstruction (`tanh` of the last decoder
#   output, plus the input and clamped when `long_skip` is set);
# - `mask`: all masks returned by the coarse decoder;
# - `reconstruct_wm`: third value returned by the coarse decoder.
# Timing sections "1.0" .. "4.0" are recorded in the global timer `to`.
function (m::SLBR)(synthesized::AbstractArray{Float32, 4})
    # Each stage is evaluated once and timed in place; `@timeit` returns the
    # value of the timed expression.
    image_code, before_pool = @timeit to "1.0" m.encoder(synthesized)
    @show size(image_code), size(before_pool[1])
    unshared_before_pool = before_pool

    im, mask0 = @timeit to "2.0" m.shared_decoder(image_code)
    @show size(im), size(mask0)

    ims, mask, wm = @timeit to "3.0" m.coarse_decoder(im, nothing, mask0, unshared_before_pool)
    @show size(ims[end]), size(mask[end])
    im = ims[end]

    reconstructed_image = Flux.tanh_fast.(im)

    if m.long_skip
        reconstructed_image = reconstructed_image + synthesized
        reconstructed_image = clamp.(reconstructed_image, zero(eltype(reconstructed_image)), one(eltype(reconstructed_image)))
    end
    reconstructed_mask = mask[end]
    reconstruct_wm = wm

    # Decoder features without the final output, in reverse order.
    dec_feats = reverse(ims[1:end-1])
    # Composite: reconstruction inside the mask, original input outside it.
    coarser = reconstructed_image .* reconstructed_mask + (one(eltype(reconstructed_mask)) .- reconstructed_mask) .* synthesized

    refine_bg = @timeit to "4.0" m.refinement(synthesized, coarser, reconstructed_mask, nothing, dec_feats)
    @show size(refine_bg)
    refine_bg = clamp.(Flux.tanh_fast.(refine_bg ) + synthesized, zero(eltype(refine_bg)), one(eltype(refine_bg)))

    return [refine_bg, reconstructed_image], mask, [reconstruct_wm]
end
Flux.@functor SLBR

# Driver script: build the model, move it to the GPU and time one forward pass.
m = SLBR(; shared_depth=2, blocks=[3 for i in 1:5], long_skip=true, k_center=2) |> gpu
@show "=================================="
# BSON.@save "model.bson" m

# Random single-image batch of size 256x256 with 3 channels.
const x = rand(Float32, 256,256, 3, 1) |> gpu

# @btime m(x) 
# @code_warntype m(x)
# First call; its timings are discarded by the reset below.
out = m(x);
reset_timer!(to)
# for i in 1:10
#     out = m(x);
# end
# Second call; only this one is reflected in the printed timer.
m(x)
# @show length(out)
# @show size(out[1][1])
# @show size(out[2][end])

# Print the accumulated timing table.
show(to)