模块:Csv.lua

来自「荏苒之境」
Sicusa留言 | 贡献2025年8月2日 (六) 01:03的版本

此模块的文档可以在模块:Csv.lua/doc创建

--[[
Modified by Phlamcenth Sicusa from:
    CSV Library v1 (Author: Michael Lutz, 2022-12-14)
    Built on: http://lua-users.org/wiki/LuaCsv
--]]

-- Lua 5.1 exposes a global `unpack`; Lua 5.2+ only has `table.unpack`.
local unpack = unpack or table.unpack

-- Byte codes of the characters the parser dispatches on:
-- double quote, carriage return, line feed, and the default delimiter (comma).
local BYTE_QUOTE, BYTE_ENTER, BYTE_NEWLINE, BYTE_COMMA = string.byte('"\r\n,', 1, 4)

--- Consume the body of a quoted field.
-- Called with `pos` just past the opening quote. Every byte up to the closing
-- quote is appended to the byte buffer `s`. Inside quotes the delimiter, CR and
-- LF are ordinary data. A doubled quote (`""`) stands for one literal quote.
-- @param input the CSV text being parsed
-- @param sep delimiter byte; not used here because delimiters inside quotes are
--   data, but kept so existing call sites remain valid
-- @param pos byte index to start reading from
-- @param s array of byte values (integers) that receives the field's bytes
-- @return byte index just past the closing quote, or the index one past the end
--   of `input` if the quote is never closed
local function parse_quoted(input, sep, pos, s)
    while true do
        local c = string.byte(input, pos)
        if c == nil then
            -- Unterminated quoted field: stop at the end of the input.
            return pos
        elseif c == BYTE_QUOTE then
            if string.byte(input, pos + 1) ~= BYTE_QUOTE then
                -- Closing quote.
                return pos + 1
            end
            -- Escaped quote. Store the byte value, not the string '"', because
            -- the buffer is later passed to string.char.
            s[#s + 1] = BYTE_QUOTE
            pos = pos + 2
        else
            s[#s + 1] = c
            pos = pos + 1
        end
    end
end

-- Maximum number of bytes passed to a single `unpack`/`string.char` call.
-- Lua 5.1 limits how many values one call can return/receive, so long fields
-- are converted in slices of this size.
local UNPACK_CHUNK = 4096

--- Convert an array of byte values into a string.
-- @param bytes sequence of integers in 0..255
-- @return the string made of those bytes ("" for an empty array)
local function bytes_to_string(bytes)
    local n = #bytes
    if n == 0 then
        return ""
    end
    if n <= UNPACK_CHUNK then
        return string.char(unpack(bytes, 1, n))
    end
    -- Long field: build it slice by slice to stay under the unpack limit.
    local parts = {}
    for first = 1, n, UNPACK_CHUNK do
        local last = math.min(first + UNPACK_CHUNK - 1, n)
        parts[#parts + 1] = string.char(unpack(bytes, first, last))
    end
    return table.concat(parts)
end

--- Parse one CSV record starting at byte index `pos`.
-- An unquoted CR byte is skipped, so CRLF line endings work. An unquoted LF ends
-- the record. Quoted sections are handled by parse_quoted and appended to the
-- current field.
-- @param input the CSV text
-- @param sep delimiter byte
-- @param pos byte index where the record starts
-- @return array of field strings, and the byte index where the next record
--   starts. The array is empty only when the end of input is reached before any
--   field content (this is how callers detect end of input).
local function parse_row(input, sep, pos)
    local fields = {}
    local buf = {}
    -- Set once a quoted section is seen. This lets a last record consisting only
    -- of `""` count as one empty field instead of looking like end of input.
    local saw_quote = false
    while true do
        local c = string.byte(input, pos)
        if c == sep then
            fields[#fields + 1] = bytes_to_string(buf)
            buf = {}
            pos = pos + 1
        elseif c == nil then
            if #fields ~= 0 or #buf ~= 0 or saw_quote then
                fields[#fields + 1] = bytes_to_string(buf)
            end
            break
        elseif c == BYTE_NEWLINE then
            fields[#fields + 1] = bytes_to_string(buf)
            pos = pos + 1
            break
        elseif c == BYTE_ENTER then
            pos = pos + 1
        elseif c == BYTE_QUOTE then
            saw_quote = true
            pos = parse_quoted(input, sep, pos + 1, buf)
        else
            buf[#buf + 1] = c
            pos = pos + 1
        end
    end
    return fields, pos
end

local csv = {}

--- Parse CSV text into an array of rows.
-- Each row is an array of field strings. Rows are read until parse_row returns
-- a row with no fields.
-- @param str the CSV text
-- @param enable_header when truthy, the first record is stored as `.header` on
--   the result. Each data row then gets a metatable so that `row[name]` returns
--   the field in the column whose header equals `name`. For duplicate header
--   names the last occurrence wins.
-- @param delimiter optional delimiter string; only its first byte is used.
--   nil, false or "" selects ",".
-- @return array of rows, plus `.header` when enable_header is set
csv.parse = function(str, enable_header, delimiter)
    local sep = BYTE_COMMA
    if delimiter then
        sep = string.byte(delimiter) or BYTE_COMMA
    end

    local result = {}
    local cursor = 1
    local row_meta = nil

    if enable_header then
        local header_row
        header_row, cursor = parse_row(str, sep, cursor)
        result.header = header_row

        -- header name -> column index
        local column_of = {}
        for index, name in ipairs(header_row) do
            column_of[name] = index
        end

        row_meta = {
            __index = function(row, key)
                local column = column_of[key]
                if not column then
                    return nil
                end
                return row[column]
            end,
        }
    end

    while true do
        local row
        row, cursor = parse_row(str, sep, cursor)
        if #row == 0 then
            break
        end
        if row_meta ~= nil then
            setmetatable(row, row_meta)
        end
        result[#result + 1] = row
    end

    return result
end

--- Render one value as a CSV field, quoting it when necessary.
-- Embedded double quotes are doubled. The field is wrapped in quotes when it
-- contains a quote, the delimiter, CR or LF.
-- @param str field value (string or number); nil/false becomes ""
-- @param sep the delimiter, a one-character string
-- @return the field text ready to be written
local function format_csv(str, sep)
    local escaped, quote_count = string.gsub(str or "", '"', '""')
    -- Search for the delimiter in plain mode. Building a pattern like "%" .. sep
    -- would turn letters into character classes (e.g. "%a" = any letter).
    local needs_quotes = quote_count > 0
        or (sep ~= "" and string.find(escaped, sep, 1, true) ~= nil)
        or string.find(escaped, "[\r\n]") ~= nil
    if needs_quotes then
        return '"' .. escaped .. '"'
    end
    return escaped
end

--- Serialize rows (and an optional header) as CSV text.
-- Every record, including the last, is terminated by "\n". An empty row is
-- written as an empty line rather than being silently dropped.
-- @param data array of rows; each row is an array of field values (strings,
--   numbers or nil)
-- @param header optional array of column names written as the first line;
--   nil/false or an empty table writes no header line
-- @param delimiter optional delimiter; only its first character is used.
--   nil/false or "" selects "," (the same fallback as csv.parse).
-- @return the CSV text
csv.format = function(data, header, delimiter)
    local sep = ','
    if delimiter and delimiter ~= "" then
        sep = string.sub(delimiter, 1, 1)
    end

    local out = {}

    -- Append one record: fields joined by `sep`, then the line terminator.
    local function write_record(fields)
        for j = 1, #fields do
            if j > 1 then
                out[#out + 1] = sep
            end
            out[#out + 1] = format_csv(fields[j], sep)
        end
        out[#out + 1] = "\n"
    end

    if header and #header > 0 then
        write_record(header)
    end

    for i = 1, #data do
        write_record(data[i])
    end

    return table.concat(out)
end

return csv