--[[*********************************************************************************************************\

Module Name:    init.lua

Description:    

References:     

    Copyright (c) 2015, Matrox Graphics Inc.
    All Rights Reserved.

\***********************************************************************************************************]]

-- -----------------------------------------------------------------------------------------------------------
--                                    I N C L U D E S   A N D   U S I N G S
-- -----------------------------------------------------------------------------------------------------------

-- includes
local h264utils = require("H264Utils")
-- externals (global variables used by the module), captured as locals because the environment switch below hides all other globals
local assert = assert
local debug = debug
local io = io
local os = os
local pairs = pairs
local print = print
local string = string
local type = type
local tostring = tostring
local tonumber = tonumber
local NALUTYPE_UNSPECIFIED = NALUTYPE_UNSPECIFIED
local NALUTYPE_SLICE = NALUTYPE_SLICE
local NALUTYPE_DPA = NALUTYPE_DPA
local NALUTYPE_DPB = NALUTYPE_DPB
local NALUTYPE_DPC = NALUTYPE_DPC
local NALUTYPE_IDR = NALUTYPE_IDR
local NALUTYPE_SEI = NALUTYPE_SEI
local NALUTYPE_SPS = NALUTYPE_SPS
local NALUTYPE_PPS = NALUTYPE_PPS
local NALUTYPE_AUD = NALUTYPE_AUD
local NALUTYPE_EOSEQ = NALUTYPE_EOSEQ
local NALUTYPE_EOSTREAM = NALUTYPE_EOSTREAM
local NALUTYPE_FILL = NALUTYPE_FILL
local NALUTYPE_SPSE = NALUTYPE_SPSE
local NALUTYPE_PREFIX_NAL = NALUTYPE_PREFIX_NAL
local NALUTYPE_SSPS = NALUTYPE_SSPS
local NALUTYPE_AUXC = NALUTYPE_AUXC
local NALUTYPE_SEXT = NALUTYPE_SEXT

-- -----------------------------------------------------------------------------------------------------------
--                                    C O N S T A N T S   A N D   T Y P E S
-- -----------------------------------------------------------------------------------------------------------

-- Set to a value > 0 to make module.patch print the parsed class fields and a per-NAL-unit trace.
local verbose = 0

-- Table returned by this file. The environment switch below makes it the global environment for the rest
-- of this chunk: every later unqualified global assignment (e.g. `function pp`) becomes a field of this
-- table, and globals that were not captured as locals above (error, table, ipairs, ...) are unreachable.
local module = {}

-- rawget is used to avoid the _G metatable; we test whether this is Lua pre-5.2, where setfenv
-- is used to change the environment of the running chunk.
if rawget(_G, "setfenv") then
    setfenv(1, module) -- Lua 5.1 / LuaJIT
else
    _ENV = module -- Lua 5.2 and later
end

-- constants using externals
-- Maps each NALUTYPE_* value to its symbolic name; only used to label NAL units in the verbose trace of
-- module.patch.
-- NOTE(review): if the host did not define one of the NALUTYPE_* globals, the matching local is nil and
-- this constructor fails with "table index is nil" when the module is loaded.
local nalu_types =
{
    [NALUTYPE_UNSPECIFIED] = "NALUTYPE_UNSPECIFIED",
    [NALUTYPE_SLICE] = "NALUTYPE_SLICE",
    [NALUTYPE_DPA] = "NALUTYPE_DPA",
    [NALUTYPE_DPB] = "NALUTYPE_DPB",
    [NALUTYPE_DPC] = "NALUTYPE_DPC",
    [NALUTYPE_IDR] = "NALUTYPE_IDR",
    [NALUTYPE_SEI] = "NALUTYPE_SEI",
    [NALUTYPE_SPS] = "NALUTYPE_SPS",
    [NALUTYPE_PPS] = "NALUTYPE_PPS",
    [NALUTYPE_AUD] = "NALUTYPE_AUD",
    [NALUTYPE_EOSEQ] = "NALUTYPE_EOSEQ",
    [NALUTYPE_EOSTREAM] = "NALUTYPE_EOSTREAM",
    [NALUTYPE_FILL] = "NALUTYPE_FILL",
    [NALUTYPE_SPSE] = "NALUTYPE_SPSE",
    [NALUTYPE_PREFIX_NAL] = "NALUTYPE_PREFIX_NAL",
    [NALUTYPE_SSPS] = "NALUTYPE_SSPS",
    [NALUTYPE_AUXC] = "NALUTYPE_AUXC",
    [NALUTYPE_SEXT] = "NALUTYPE_SEXT",
}

-- Hex text of the NAL unit that module.patch decodes with string_fromhex and writes (start code first, then
-- padding) before an IDR NAL unit that directly follows an AUD or a PPS.
-- Presumably an SEI NAL unit: the leading byte 0x06 carries nal_unit_type 6, and the next byte 0x05 would be
-- SEI payload type 5 (user_data_unregistered) -- TODO confirm against the target format specification.
-- Must contain an even number of hexadecimal digits.
local sei_type1_data =
"0605FFF2F7493EB3D40047968686C9707B64372A554D494413FF0000FF0000FF14FF0000FF0000FF60FFFFFFFF22FFFFFFFFFFFFFFFF" ..
"FFFFFFFFFFFFFFFFFFFF62FF0000FF0000FF63FF0000FF0000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" ..
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" ..
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" ..
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" ..
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" ..
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" ..
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF80"

-- -----------------------------------------------------------------------------------------------------------
--                                                   C O D E
-- -----------------------------------------------------------------------------------------------------------

--[[*********************************************************************************************************\

Function:       string_fromhex

Description:    Decodes a string of hexadecimal digit pairs into the corresponding binary string.

Parameters:     str - hex text such as "06FF80"; upper- and lower-case digits are accepted

Comments:       Raises an error when str has an odd number of characters or contains a non-hex character.
                Without these checks gsub would leave an unpaired trailing digit in the result as a literal
                character, silently corrupting the decoded bytes.
                Uses assert rather than error: after the module environment switch only the captured
                locals (assert, string, tonumber, ...) are reachable.
 
************************************************************************************************************]]
function string_fromhex(str)
    assert(#str % 2 == 0, "string_fromhex: odd number of hex digits")
    -- %X is the complement of %x, i.e. any character that is not a hexadecimal digit
    assert(not str:find("%X"), "string_fromhex: non-hexadecimal character in input")
    return (str:gsub('..', function (cc)
        return string.char(tonumber(cc, 16))
    end))
end

--[[*********************************************************************************************************\

Function:       pp_render (local helper of _pp)

Description:    Renders obj as text. The first line carries no indentation; nested lines and the closing brace
                of a table are indented relative to depth (two spaces per level).

Parameters:     obj   - any Lua value
                depth - nesting level of obj
                seen  - map table -> key under which that table was first rendered
 
************************************************************************************************************]]
local function pp_render(obj, depth, seen)
    if type(obj) ~= "table" then
        return tostring(obj)
    end
    local inner = string.rep("  ", depth + 1)
    local str = "{\n"
    for k, v in pairs(obj) do
        local value
        if type(v) == "table" and seen[v] ~= nil then
            -- Already rendered: print the key it was first reached under (this also stops cycles).
            value = tostring(seen[v])
        else
            -- Only tables are tracked; equal scalars (e.g. two 1s) must be printed as values.
            if type(v) == "table" then
                seen[v] = k
            end
            value = pp_render(v, depth + 1, seen)
        end
        str = str .. inner .. tostring(k) .. " = " .. value .. ",\n"
    end
    return str .. string.rep("  ", depth) .. "}"
end

--[[*********************************************************************************************************\

Function:       _pp

Description:    Returns a human-readable, multi-line representation of obj (tables are expanded recursively).

Parameters:     obj    - any Lua value
                indent - optional starting indentation level (two spaces per level), default 0
                path   - optional map table -> key, shared across the recursion; defaults to a new table

Comments:       A table that was already visited is printed as the key under which it was first reached, which
                also keeps self-referencing tables from recursing forever. Key order follows pairs() and is
                therefore unspecified.
 
************************************************************************************************************]]
function _pp(obj, indent, path)
    local depth = indent or 0
    return string.rep("  ", depth) .. pp_render(obj, depth, path or {})
end

--[[*********************************************************************************************************\

Function:       pp

Description:    Debugging aid: writes the text produced by _pp for obj to standard output.

Parameters:     obj - any Lua value; tables are expanded recursively by _pp

Comments:       Always renders from indentation level 0.
 
************************************************************************************************************]]
function pp(obj)
    local rendered = _pp(obj)
    print(rendered)
end


--[[*********************************************************************************************************\

Function:       dump_nal

Description:    Writes one NAL unit to the output: a start code followed by the NAL unit bytes.

Parameters:     output         - file handle (any object with a file-like write method)
                startcode_size - 4 selects the start code 00 00 00 01; any other value selects 00 00 01
                payload        - NAL unit bytes, without start code

Comments:       Raises an error if the write fails.
                start_code must stay local: because of the module environment switch, an unqualified
                assignment would store it as a field of the returned module table.
 
************************************************************************************************************]]
local function dump_nal(output, startcode_size, payload)
    local start_code
    if startcode_size == 4 then
        start_code = string.char(0, 0, 0, 1)
    else
        start_code = string.char(0, 0, 1)
    end
    assert(output:write(start_code, payload))
end

--[[*********************************************************************************************************\

Function:       apply_padding

Description:    Writes `padding` zero bytes to the output; does nothing when padding <= 0.

Parameters:     output  - file handle (any object with a file-like write method)
                padding - number of zero bytes to append

Comments:       Zeros are written explicitly instead of seeking forward. A seek past the current end of file
                does not extend the file unless something is written afterwards, so padding after the last NAL
                unit would be lost. Seeking also fails on non-seekable outputs. Raises an error if the write
                fails.
 
************************************************************************************************************]]
local function apply_padding(output, padding)
    if padding > 0 then
        assert(output:write(string.rep("\0", padding)))
    end
end

--[[*********************************************************************************************************\

Function:       insert_padded_sei (local helper of module.patch)

Description:    Writes the SEI NAL unit with dump_nal, then apply_padding for the remainder of the SEI area, so
                that start code + SEI bytes + padding add up to (sei_padding_size - 6) bytes.

Parameters:     output           - output file handle
                startcode_size   - start code size passed through to dump_nal
                sei_data         - binary SEI NAL unit, without start code
                sei_padding_size - size of the SEI area in bytes

Comments:       Raises an error if the SEI NAL unit does not fit in the area.
                NOTE(review): the reason for the 6-byte reduction is not visible in this file - confirm.
 
************************************************************************************************************]]
local function insert_padded_sei(output, startcode_size, sei_data, sei_padding_size)
    local sei_zero_bytes = sei_padding_size - #sei_data - startcode_size - 6
    assert(sei_zero_bytes >= 0, "SEI data does not fit in the SEI padding area")
    dump_nal(output, startcode_size, sei_data)
    apply_padding(output, sei_zero_bytes)
end

--[[*********************************************************************************************************\

Function:       module.patch

Description:    Pads an AVC elementary stream in place. options.output is renamed to "<options.output>.tmp",
                read back NAL unit by NAL unit through H264Utils, and rewritten to options.output with:
                  - each SPS NAL unit padded to 250 bytes and each PPS NAL unit to 256 bytes (start code
                    included);
                  - a padded SEI NAL unit inserted before an IDR NAL unit that directly follows an AUD;
                  - a filler NAL unit, then a padded SEI NAL unit, inserted before an IDR NAL unit that
                    directly follows a PPS.
                The temporary file is deleted afterwards.

Parameters:     class   - format string matching "avc_<type>_class<N>_<width>x<height><i|p>...",
                          e.g. "avc_intra_class100_1920x1080p_60Mfps"
                options - table; options.output is the path of the stream to patch (required)

Comments:       The SEI area is 18*512 bytes when the height is >= 1080, 10*512 bytes otherwise.
                Raises an error if the class string is not recognized, the file cannot be renamed or
                reopened, or a NAL unit is larger than its padded size.
 
************************************************************************************************************]]
function module.patch(class, options)
    local class_type, class_number, str_width, str_height, scanmode =
        class:match("avc_(.-)_(class%d-)_(%d-)x(%d-)([ip])")
    assert(class_type, "unrecognized AVC class: " .. class)

    local tmp_input_fn = assert(options.output, "The output filename must be specified") .. ".tmp"
    os.remove(tmp_input_fn)
    -- Fail here, with the OS message, rather than later inside h264utils.open on a missing file.
    assert(os.rename(options.output, tmp_input_fn))
    local file = h264utils.open(tmp_input_fn)
    local output = assert(io.open(options.output, "wb"))

    if verbose > 0 then
        print("class_type=" .. class_type .. " class_number=" .. class_number)
        print(" str_width=" .. str_width .. " str_height=" .. str_height .. " scanmode=" .. scanmode)
    end
    print("patching " .. options.output)

    local sei_padding_size
    if tonumber(str_height) >= 1080 then
        sei_padding_size = 18*512
    else
        sei_padding_size = 10*512
    end

    local sei_data = string_fromhex(sei_type1_data)
    -- Filler NAL unit inserted between a PPS and the SEI (bytes 0x0C 0x80). Kept local so it does not
    -- leak into the module table through the environment switch.
    local filler = string.char(0xC, 0x80)

    local nalu_num = 0
    local bytes_count = 0
    local last_nalu_type = NALUTYPE_UNSPECIFIED
    for nalu in file.each_nalu, file do
        local ntype          = nalu:type()
        local startcode_size = nalu:start_code_type()
        local payload        = nalu:data()
        local ttl_size       = startcode_size + #payload
        bytes_count          = bytes_count + ttl_size
        nalu_num             = nalu_num + 1

        local padding_bytes = 0
        if ntype == NALUTYPE_SPS then
            padding_bytes = 256 - ttl_size - 6
        elseif ntype == NALUTYPE_PPS then
            padding_bytes = 256 - ttl_size
        elseif ntype == NALUTYPE_IDR then
            if last_nalu_type == NALUTYPE_PPS then
                dump_nal(output, startcode_size, filler)
            end
            if last_nalu_type == NALUTYPE_AUD or last_nalu_type == NALUTYPE_PPS then
                insert_padded_sei(output, startcode_size, sei_data, sei_padding_size)
            end
        end
        assert(padding_bytes >= 0, "parameter set NAL unit is larger than its padded size")

        dump_nal(output, startcode_size, payload)
        apply_padding(output, padding_bytes)

        last_nalu_type = ntype
        if verbose > 0 then
            -- Fall back to the numeric type for NAL unit types that nalu_types does not name.
            local type_name = tostring(nalu_types[ntype] or ntype)
            print("nalu #" .. nalu_num .. " type:" .. type_name .. " ttl_size:" .. ttl_size .." bc:" .. bytes_count)
        end
    end
    output:close()

    -- clean up the original "input" file
    os.remove(tmp_input_fn)
    print("done patching " .. options.output)
end

-- Standalone unit test: expects an output.264 in the current folder, i.e. an AVC 1920x1080p CBG file that is
-- missing the PPS, SPS and SEI padding. The file is copied to output2.264 and that copy is patched.
-- `standalone` is true only when this file runs as the main chunk and was not loaded through require().
-- Because of the module environment switch, this assignment stores `standalone` on the module table.
standalone =
    (debug.getinfo(1) and debug.getinfo(1).what == "main") and
    ((not debug.getinfo(2)) or (debug.getinfo(2).name ~= "require"))

-- Binary file copy done in Lua (instead of shelling out to `cp`) so the self-test also runs on Windows and
-- stops with an error if the source file is missing.
local function copy_file(src, dst)
    local input = assert(io.open(src, "rb"))
    local data = input:read("*a")
    input:close()
    local copy = assert(io.open(dst, "wb"))
    assert(copy:write(data))
    copy:close()
end

if standalone then
    print("Unit test: pad output.264 as if it was a 1080p intra only AVC file.")
    copy_file("output.264", "output2.264")
    module.patch(
        "avc_intra_class100_1920x1080p_60Mfps",
    {
        output = "output2.264",
    })
end

return module
