I need help with the following code:

import csv
import datetime
import os
from dataclasses import dataclass
from datetime import timedelta
from functools import partial
import json
import logging
import subprocess
import zipfile
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TextIO, Any, List

import boto3
import smart_open
from botocore.client import BaseClient
from smart_open import register_compressor
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from airflow_utils_plugin.plugin.base_classes import (
WithDefaultValues,
DefaultValue,
)
from airflow_utils_plugin.plugin.profiling import check_memory
from airflow_utils_plugin.plugin.s3 import copy_file_in_s3

# Directory that holds the intermediate raw dump and temporary archives.
TEMP_DIR = "/tmp/working"
DEFAULT_FETCH_SIZE = 25000
BEAT_TIME = timedelta(minutes=5)
ROW_COUNT = 0
# High-water mark of memory usage, maintained by log_memory_usage().
MAX_MEMORY = 0
# Single-character stand-in handed to csv.writer when the requested output
# delimiter is longer than one character; CsvMultiDelimiterHandler.write()
# swaps it for the real delimiter.
MULTICHAR_DELIMITER_MARKER = "\t"
# Format that format_date() expects date strings from the raw file to be in.
RAW_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

@dataclass
class ExtractMeta:
    """Metadata about the executed query that the writers need later on."""

    # SQLAlchemy driver name of the engine that ran the query
    # (compared against "cx_oracle" / "psycopg2" by TypedWriter).
    driver: str
    # DB-API cursor description; element [0] of each entry is used as the
    # column name and element [1] as the column type.
    cursor_description: Any

@dataclass(eq=False)
class ExtractDefinition(WithDefaultValues):
    """Describes one output file that has to be produced from a query.

    Equality is deliberately *not* the generated dataclass equality: two
    definitions compare equal when the bytes written for them would be
    identical, so that the second one can be produced by copying the first
    (see ``__eq__``).
    """

    # Target location of the extract; its suffix selects how it is written.
    filename: str
    # Name of the data file stored inside a .zip/.rar archive.
    uncompressed_filename: str
    # Further target filenames that receive a copy of this extract.
    extract_copy: list
    include_header: bool
    date_format: str
    quotechar: str
    # NOTE(review): DefaultValue comes from airflow_utils_plugin; presumably
    # it marks the fallback used when no value is supplied - confirm.
    delimiter: str = DefaultValue(",")
    escapechar: str = DefaultValue("\\")
    eol: str = DefaultValue("\n")
    encoding: str = DefaultValue("utf-8")
    encoding_errors: str = DefaultValue("strict")
    extract_type: str = DefaultValue("CSV")

    def __eq__(self, other):
        """Return True when ``other`` would produce identical file content.

        ``filename`` itself is not compared (duplicates go to different
        targets), only its suffix. Archive outputs (.rar/.zip, matched
        case-insensitively in the same way write_extract_file dispatches)
        never compare equal, not even to themselves.

        NOTE(review): to de-duplicate archives as well, equality would also
        have to compare ``uncompressed_filename`` (the member name stored
        inside the archive), and the ``extract_copy`` targets of archive
        outputs would have to be copied after the archive is uploaded.
        """
        if not isinstance(other, ExtractDefinition):
            return NotImplemented

        own_suffix = Path(self.filename).suffix
        other_suffix = Path(other.filename).suffix
        archive_suffixes = (".rar", ".zip")
        # lower() keeps this in step with write_extract_file, which routes
        # ".ZIP" to zip_extract as well.
        if (
            own_suffix.lower() in archive_suffixes
            or other_suffix.lower() in archive_suffixes
        ):
            return False
        if own_suffix != other_suffix:
            return False

        compared_fields = (
            "include_header",
            "date_format",
            "quotechar",
            "delimiter",
            "escapechar",
            "eol",
            "encoding",
            "encoding_errors",
            "extract_type",
        )
        return all(
            getattr(self, name) == getattr(other, name)
            for name in compared_fields
        )

def log_memory_usage(f_name):
    """Log current memory usage and the highest value seen so far.

    :param f_name: label (typically the calling function's name) that
        prefixes the log line
    """
    global MAX_MEMORY
    mem = check_memory()
    # keep the module-level high-water mark up to date
    if mem > MAX_MEMORY:
        MAX_MEMORY = mem
    logging.info(
        f"{f_name}: memory usage (MB): {mem:,.2f} (max: {MAX_MEMORY:,.2f})"
    )

def stringify(s, delimiter=None):
    """
    Simple function that turns most data types into a string

    :param s: value to convert; ``None`` becomes the empty string
    :param delimiter: optional delimiter that is stripped from ``str``
        values (passed through to ``format_string``)
    :return: string representation of ``s``
    """
    if s is None:
        return ""
    if isinstance(s, str):
        return format_string(s, delimiter)
    return str(s)

def format_date(s, date_format):
    """Re-format a raw date string into ``date_format``.

    :param s: date string expected to be in ``RAW_DATE_FORMAT``
    :param date_format: ``strftime`` format of the returned string
    :return: the re-formatted date, or ``stringify(s)`` when ``s`` cannot
        be parsed as a date
    """
    # this date comes from an intermediate raw file where dates are just
    # strings in format RAW_DATE_FORMAT
    try:
        d = datetime.datetime.strptime(s, RAW_DATE_FORMAT)
        return d.strftime(date_format)
    except (AttributeError, TypeError, ValueError):
        # not a parseable date (e.g. empty value or header text)
        return stringify(s)

def format_number(n):
    """Render a number as text without the leading zero of a fraction.

    ``0.5`` becomes ``".5"`` and ``-0.5`` becomes ``"-.5"``; every other
    value is returned as ``str(n)`` and ``None`` as the empty string.
    """
    if n is None:
        return ""
    s = str(n)
    if s.startswith("0."):
        return s[1:]  # strip leading 0
    if s.startswith("-0."):
        return "-" + s[2:]  # keep the sign, drop the 0
    return s

def format_string(s, delimiter=None):
    """Strip characters from ``s`` that would corrupt a delimited row.

    Newlines and the internal ``MULTICHAR_DELIMITER_MARKER`` are always
    removed; ``delimiter`` is removed as well when one is given.
    """
    new_s = s.replace("\n", "")
    new_s = new_s.replace(MULTICHAR_DELIMITER_MARKER, "")
    if delimiter:
        new_s = new_s.replace(delimiter, "")
    return new_s

class CsvMultiDelimiterHandler:
    """File-like adapter for delimiters longer than one character.

    ``csv.writer`` only accepts one-character delimiters, so rows are
    written with ``MULTICHAR_DELIMITER_MARKER`` and ``write`` replaces the
    marker with the real separator before passing the text to the wrapped
    stream. Iterating the handler does the reverse for reading: each line
    of the stream is split on the separator and returned as a quoted,
    comma-separated line.
    """

    def __init__(self, handle: TextIO, separator):
        """
        :param handle: text stream that is read from / written to
        :param separator: the (multi-character) delimiter
        """
        self.stream = handle
        self.separator = separator

    def __iter__(self):
        """Yield every line of the stream converted to a quoted CSV line."""
        for row in self.stream:
            groups = row.split(self.separator)
            # drop embedded quotes, surrounding whitespace and NUL bytes,
            # then wrap every cell in double quotes
            groups = [
                '"' + c.replace('"', "").strip().replace("\x00", "") + '"'
                for c in groups
            ]
            yield ",".join(groups)

    def __next__(self):
        """Return an iterator over the converted lines.

        Kept for backward compatibility: this method always returned a
        generator rather than a single line. Prefer iterating the handler
        directly.
        """
        return iter(self)

    def write(self, row):
        """Write ``row`` with the marker replaced by the real separator."""
        vec = row.split(MULTICHAR_DELIMITER_MARKER)
        out = self.separator.join(vec)
        self.stream.write(out)

class TypedWriter:
    """``csv.writer`` wrapper that formats every column before writing.

    One formatter per column is chosen from the cursor description: by
    default values go through ``stringify``; for Oracle, DATE and NUMBER
    columns get ``format_date`` / ``format_number`` when a date format is
    configured.
    """

    def __init__(
        self,
        f,
        driver,
        cursor_description,
        date_format,
        output_delimiter,
        **kwargs,
    ):
        """
        :param f: object with a ``write`` method that receives the rows
        :param driver: SQLAlchemy driver name the query was run with
        :param cursor_description: DB-API cursor description of the query
        :param date_format: ``strftime`` format for date columns; falsy to
            leave all columns on the default formatter
        :param output_delimiter: delimiter of the final output file
        :param kwargs: passed unchanged to ``csv.writer``
        """
        self.writer = csv.writer(f, **kwargs)
        self.date_format = date_format
        self.driver = driver
        self.cursor_description = cursor_description
        # we use the output_delimiter to determine whether to strip
        # occurrences of the delimiter from strings in the output, we'll
        # only do this if no quotechar is provided - i.e. we ASSUME that if
        # we have a quotechar then the consumer of the extract can handle
        # the delimiter appearing within a field value
        self.output_delimiter = (
            None if kwargs.get("quotechar", None) else output_delimiter
        )
        self.formatters = self._formatters()

    def _oracle_formatters(self, formatters):
        """Install date/number formatters for cx_Oracle column types."""
        column_type_names = [e[1].name for e in self.cursor_description]
        logging.info(f"column_type_names: {column_type_names}")
        for i, name in enumerate(column_type_names):
            if name == "DB_TYPE_DATE":
                formatters[i] = partial(
                    format_date, date_format=self.date_format
                )
            elif name == "DB_TYPE_NUMBER":
                formatters[i] = format_number
        return formatters

    def _postgres_formatters(self, formatters):
        """Log the column types; formatters are returned unchanged."""
        column_type_codes = [e[1] for e in self.cursor_description]
        logging.info(f"column_type_codes: {column_type_codes}")
        # FIXME: this code is untested as haven't come across a case yet
        # where we have a query returning a DATE-ish column that needs
        # to be formatted accordingly, so for the moment we'll just log
        # some stuff
        # for i, code in enumerate(column_type_codes):
        #     if code == "DB_TYPE_DATE":
        #         formatters[i] = partial(
        #             format_date, date_format=self.date_format
        #         )
        for column in self.cursor_description:
            logging.info(f"FIXME: _postgres_formatters: {column}")
        logging.info(
            "FIXME: _postgres_formatters: all columns currently treated "
            "as strings"
        )
        return formatters

    def _formatters(self):
        """Build the list of per-column formatter callables."""
        # every column starts with the default formatter; the list holds
        # the same partial object N times, which is fine because entries
        # are only ever replaced, never mutated
        formatters = [
            partial(stringify, delimiter=self.output_delimiter)
        ] * len(self.cursor_description)
        if self.date_format:
            # we need to check if we have any columns of type: DB_TYPE_DATE
            # because if we do we're going to need to override the output
            # to match the specified date_format
            if self.driver == "cx_oracle":
                formatters = self._oracle_formatters(formatters=formatters)
            elif self.driver == "psycopg2":
                formatters = self._postgres_formatters(formatters=formatters)
            else:
                logging.warning(
                    "No function provided to determine column types "
                    f"from cursor for driver: {self.driver}"
                )
        # if f is a partial then it has no property __name__ that we can
        # access here
        logging.info(f"formatters: {list(formatters)}")
        return formatters

    def writerow(self, row):
        """Format each value with its column's formatter and write the row."""
        self.writer.writerow(
            [self.formatters[i](col) for i, col in enumerate(row)]
        )

    def writerows(self, rows):
        """Write every row of ``rows`` via ``writerow``."""
        for row in rows:
            self.writerow(row)

def _handle_lzma(file_obj, mode):
    """Compressor callback for smart_open: wrap ``file_obj`` in an XZ stream.

    :param file_obj: underlying binary file object (or filename)
    :param mode: mode the LZMA file is opened with
    :return: ``lzma.LZMAFile`` using the ``FORMAT_XZ`` container
    """
    import lzma

    return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)

# Module-level on purpose: the compressors are registered with smart_open
# as soon as this module is imported.
# NOTE(review): files written with the ".7z" suffix contain an XZ stream
# (see _handle_lzma), presumably not a real 7z archive - confirm that this
# is what consumers expect.
register_compressor(
    ".7z", _handle_lzma
)  # indenting is correct here - we call this when module imported

register_compressor(
    ".xz", _handle_lzma
)  # indenting is correct here - we call this when module imported

def execute_shell_command(args):
    """Run a command and return its exit code and captured stdout.

    :param args: command and arguments as a list (no shell is involved)
    :return: tuple ``(return_code, output)``; ``output`` is a list of
        decoded, stripped strings (a single element holding all of stdout)
    """
    logging.info(msg=f"Executing: {args}")
    try:
        output = subprocess.check_output(args=args)
    except subprocess.CalledProcessError as e:
        # a non-zero exit is reported to the caller instead of raised
        res = e.returncode
        output = e.output
    else:
        res = 0
    if not isinstance(output, list):
        output = [output]
    return res, [chunk.decode("utf-8").strip() for chunk in output]

def write_csv_file(
    file_handle,
    meta: ExtractMeta,
    extract_definition: ExtractDefinition,
    raw_data_file_name: str,
):
    """Convert the raw dump into the delimited file an extract asks for.

    :param file_handle: writable text handle of the output file
    :param meta: driver and cursor description of the executed query
    :param extract_definition: formatting options of the extract
    :param raw_data_file_name: path of the raw CSV dump to read
    :return: number of rows written, including the header row if any
    """
    delimiter = extract_definition.delimiter
    if len(delimiter) > 1:
        # csv.writer cannot handle multi-character delimiters: write with
        # a marker and let the handler swap in the real delimiter
        target = CsvMultiDelimiterHandler(file_handle, separator=delimiter)
        writer_delimiter = MULTICHAR_DELIMITER_MARKER
    else:
        target = file_handle
        writer_delimiter = delimiter

    csv_writer = TypedWriter(
        target,
        driver=meta.driver,
        cursor_description=meta.cursor_description,
        date_format=extract_definition.date_format,
        output_delimiter=delimiter,
        delimiter=writer_delimiter,
        quotechar=extract_definition.quotechar,
        escapechar=extract_definition.escapechar,
        quoting=csv.QUOTE_NONNUMERIC
        if extract_definition.quotechar
        else csv.QUOTE_NONE,
        lineterminator=extract_definition.eol,
    )

    logging.info(
        f"write_csv_file: include_header={extract_definition.include_header}"
    )
    row_count = 0

    if extract_definition.include_header:
        logging.info("write_csv_file: adding header information")
        csv_writer.writerow([t[0] for t in meta.cursor_description])
        row_count += 1

    with open(
        raw_data_file_name, mode="r", newline="", encoding="utf-8"
    ) as raw_data_file:
        # NUL characters are removed before the line reaches csv.reader
        csv_reader = csv.reader(row.replace("\0", "") for row in raw_data_file)
        for row in csv_reader:
            csv_writer.writerow(row)
            row_count += 1
            if row_count % 100000 == 0:
                log_memory_usage("write_csv_file")

    file_handle.flush()
    return row_count

def write_json_file(file_handle, meta: ExtractMeta, raw_data_file_name: str):
    """Write the raw dump as JSON lines (one object per row).

    :param file_handle: writable text handle of the output file
    :param meta: cursor description supplying the column names (keys)
    :param raw_data_file_name: path of the raw CSV dump to read
    :return: number of rows written
    """
    header = [t[0] for t in meta.cursor_description]
    row_count = 0
    with open(
        raw_data_file_name, mode="r", newline="", encoding="utf-8"
    ) as raw_data_file:
        csv_reader = csv.reader(raw_data_file)
        for row in csv_reader:
            # map result set into dict and remove NUL == \0 values.
            # explicit check for "value is not None" otherwise float
            # values (0.0) would turn to None
            # TODO: consider encoding each value into utf-8 explicitly.
            # this might be a potential solution to make NLS_LANG work
            # based on DB value.
            record = {
                header[i]: (
                    stringify(value).replace("\0", "")
                    if value is not None
                    else None
                )
                for i, value in enumerate(row)
            }
            file_handle.write(json.dumps(record) + "\n")
            row_count += 1
            if row_count % 100000 == 0:
                log_memory_usage("write_json_file")

    logging.info(f"write_json_file: {row_count} rows written")

    return row_count

def _log_extract_methodology(extract_definition):
    """Log the encoding and the encoding error policy of an extract."""
    logging.info(
        f"extracted data will be encoded as {extract_definition.encoding}, "
        f'response to errors will be "{extract_definition.encoding_errors}"'
    )

def unload_raw_data(conn_uri, sql, fetch_size, raw_data_file) -> ExtractMeta:
    """Run ``sql`` and dump the complete result set as CSV.

    :param conn_uri: SQLAlchemy connection URI
    :param sql: query text to execute
    :param fetch_size: cursor array size (only applied for cx_oracle)
    :param raw_data_file: writable text handle that receives the raw rows
    :return: ExtractMeta holding the driver name and cursor description
    """
    os.environ["NLS_LANG"] = ".AL32UTF8"
    engine = create_engine(conn_uri)
    db_session = sessionmaker(bind=engine)
    session = db_session()
    try:
        conn = session.connection()
        logging.info(f"unload_raw_data: running query:\n{sql}")

        # if this is an oracle connection then we're going to query
        # using the cursor object so we can tune some settings for
        # cx_oracle driver
        if engine.driver == "cx_oracle":
            logging.info(
                f"(oracle) DPI_DEBUG_LEVEL={os.environ.get('DPI_DEBUG_LEVEL', 'NOT SET')}"  # noqa E501
            )
            cur = conn.connection.cursor()
            logging.info(
                f"(oracle) setting cursor.prefetchrows={fetch_size + 1}"
            )
            cur.prefetchrows = fetch_size + 1
            logging.info(f"(oracle) setting cursor.arraysize={fetch_size}")
            cur.arraysize = fetch_size
            result = cur.execute(sql)
            cursor_description = result.description
        else:
            result = conn.execute(sql)
            cursor_description = result.cursor.description

        logging.info("unload_raw_data: done executing sql")
        logging.info(
            "unload_raw_data: starting dumping data to temporary file"
        )
        writer = csv.writer(raw_data_file)
        writer.writerows(result)  # write data
        raw_data_file.flush()
        logging.info(
            "unload_raw_data: finished dumping data to temporary file"
        )
        return ExtractMeta(
            driver=engine.driver, cursor_description=cursor_description,
        )
    finally:
        # release the database connection even when the query or the dump
        # fails; everything needed later has been captured by now
        session.close()
        engine.dispose()

def write_extract_file(
    meta: ExtractMeta,
    extract_definition: ExtractDefinition,
    raw_data_file_name: str,
    client: BaseClient,
):
    """Create one extract and copy it to all of its ``extract_copy`` targets.

    The suffix of ``extract_definition.filename`` selects the writer:
    ``.rar`` and ``.zip`` are built locally and uploaded, everything else
    is streamed. The finished object is then copied to every entry of
    ``extract_copy``, regardless of which writer produced it, so archive
    outputs can be duplicated the same way as streamed ones.

    :param meta: driver and cursor description of the executed query
    :param extract_definition: definition of the extract to create
    :param raw_data_file_name: path of the raw CSV dump
    :param client: boto3 S3 client handed to smart_open
    :return: number of rows written
    """
    file_path = Path(extract_definition.filename)
    logging.info(f"write_extract_file: {file_path}")

    suffix = file_path.suffix.lower()
    if suffix == ".rar":
        extractor = rar_extract
    elif suffix == ".zip":
        extractor = zip_extract
    else:  # .gz, .bz2, .7z, .xz, uncompressed
        extractor = stream_extract

    row_count = extractor(
        meta=meta,
        extract_definition=extract_definition,
        raw_data_file_name=raw_data_file_name,
        client=client,
    )

    copy_targets = extract_definition.extract_copy or []
    if copy_targets:
        # filename is expected to look like "s3://<bucket>/<key>"
        parts = extract_definition.filename.split("/")
        bucket = parts[2]
        source_key = "/".join(parts[3:])
        for target_file_name in copy_targets:
            target_parts = target_file_name.split("/")
            if target_parts[2:3] != [bucket]:
                # the copy helper is called with a single bucket, so the
                # copy ends up in the source bucket
                logging.warning(
                    f"write_extract_file: {target_file_name} is not in "
                    f"bucket {bucket}; copy will be written to {bucket}"
                )
            target_key = "/".join(target_parts[3:])
            copy_file_in_s3(bucket, source_key, target_key)

    logging.info(f"write_extract_file: {row_count:,} rows written")
    return row_count

def stream_extract(
    meta: ExtractMeta,
    extract_definition: ExtractDefinition,
    raw_data_file_name: str,
    client: BaseClient,
):
    """Write the extract straight to its target through smart_open.

    :param meta: driver and cursor description of the executed query
    :param extract_definition: definition of the extract to create
    :param raw_data_file_name: path of the raw CSV dump
    :param client: boto3 S3 client handed to smart_open
    :return: number of rows written
    """
    _log_extract_methodology(extract_definition=extract_definition)
    with smart_open.open(
        extract_definition.filename,
        "w",
        transport_params={"client": client},
        encoding=extract_definition.encoding,
        errors=extract_definition.encoding_errors,
    ) as file_handle:
        row_count = extract_to_file(
            file_handle=file_handle,
            meta=meta,
            extract_definition=extract_definition,
            raw_data_file_name=raw_data_file_name,
        )
    return row_count

def extract_to_file(
    file_handle,
    meta: ExtractMeta,
    extract_definition: ExtractDefinition,
    raw_data_file_name: str,
):
    """Write the extract in the format named by ``extract_type``.

    Only CSV is written. For JSON a message is logged, nothing is written
    and 0 is returned; any other type also returns 0.

    :param file_handle: writable text handle of the output file
    :param meta: driver and cursor description of the executed query
    :param extract_definition: definition of the extract to create
    :param raw_data_file_name: path of the raw CSV dump
    :return: number of rows written
    """
    row_count = 0
    extract_type = extract_definition.extract_type.upper()
    if extract_type == "CSV":
        logging.info("extracting data into CSV format")
        row_count = write_csv_file(
            file_handle=file_handle,
            meta=meta,
            extract_definition=extract_definition,
            raw_data_file_name=raw_data_file_name,
        )
    elif extract_type == "JSON":
        logging.info(
            "extracting data into JSON format is unsupported in this version"
        )
    return row_count

def validate_extract_definitions(extract_definitions: List[ExtractDefinition]):
    """Check that every definition names a known extract type.

    :param extract_definitions: definitions to check
    :raises ValueError: if an ``extract_type`` is neither CSV nor JSON
        (compared case-insensitively)
    """
    for ed in extract_definitions:
        if ed.extract_type.upper() not in ("CSV", "JSON"):
            raise ValueError(
                f'Unrecognized value "{ed.extract_type}" for extract_format. '
                "Expected one of [CSV, JSON]. "
                "Unable to proceed."
            )

def create_extract(conn_uri, sql, fetch_size, extract_definitions) -> int:
    """
    Extract data into CSV file using SQLAlchemy abstraction

    The query is run once and dumped to a temporary raw file; every
    distinct extract definition is then produced from that file.
    Definitions that compare equal to an earlier one are not generated
    again; their filename is appended to the earlier definition's
    ``extract_copy`` list instead.

    :param conn_uri: SQLAlchemy connection URI
    :param sql: query text to execute
    :param fetch_size: number of records to fetch each time we access
                       cursor (load into memory)
    :param extract_definitions: singular or list of ExtractDefinition
            dataclasses defining the required extracts to be created
    :return: row_count of the last extract written (0 if none)
    """
    row_count = 0
    if isinstance(extract_definitions, ExtractDefinition):
        extract_definitions = [extract_definitions]

    validate_extract_definitions(extract_definitions)

    # ensure working directory exists
    Path(TEMP_DIR).mkdir(parents=True, exist_ok=True)

    # log a baseline for memory
    log_memory_usage("create_extract")

    # we create a boto3 S3 client here so that it can be shared among
    # threads - in theory this will ensure we avoid any thread-safety
    # issues with boto3 session object
    session = boto3.Session()
    client = session.client("s3")

    with NamedTemporaryFile(
        dir=TEMP_DIR, mode="w", newline="", encoding="utf-8"
    ) as raw_data_file:
        meta = unload_raw_data(
            conn_uri=conn_uri,
            sql=sql,
            fetch_size=fetch_size,
            raw_data_file=raw_data_file,
        )

        # there's no point running the exact same extract more than once
        # sometimes we specify the same extract multiple times because the
        # same output is sent to different customers
        extracts = []
        for extract_definition in extract_definitions:
            first_match = next(
                (
                    candidate
                    for candidate in extracts
                    if extract_definition == candidate
                ),
                None,
            )
            if first_match is None:
                extracts.append(extract_definition)
            else:
                # duplicate: remember its target on the extract that will
                # actually be generated
                first_match.extract_copy.append(extract_definition.filename)

        # the raw file is deleted when the with-block ends, so all
        # extracts have to be written inside it
        for extract_definition in extracts:
            logging.info(extract_definition)
            row_count = write_extract_file(
                meta=meta,
                extract_definition=extract_definition,
                raw_data_file_name=raw_data_file.name,
                client=client,
            )

    return row_count

def zip_extract(
    meta: ExtractMeta,
    extract_definition: ExtractDefinition,
    raw_data_file_name: str,
    client: BaseClient,
):
    """Write the extract locally, zip it and upload the archive.

    :param meta: driver and cursor description of the executed query
    :param extract_definition: definition of the extract to create;
        ``uncompressed_filename`` is the member name inside the archive
    :param raw_data_file_name: path of the raw CSV dump
    :param client: boto3 S3 client handed to smart_open
    :return: number of rows written
    """
    # NamedTemporaryFile does not support "errors" parameter until 3.8 so
    # we're using TemporaryDirectory to hold a named file here
    _log_extract_methodology(extract_definition=extract_definition)
    copy_chunk_size = 1024 * 1024
    with TemporaryDirectory(dir=TEMP_DIR) as tempdir:
        extract_filename = (
            Path(tempdir) / extract_definition.uncompressed_filename
        )
        # create the extract file; it is closed (and thereby fully
        # written) before it is added to the archive.
        # newline="" as recommended by the csv module so the configured
        # eol is written untranslated
        with open(
            extract_filename,
            "w",
            newline="",
            encoding=extract_definition.encoding,
            errors=extract_definition.encoding_errors,
        ) as extract_file:
            row_count = extract_to_file(
                file_handle=extract_file,
                meta=meta,
                extract_definition=extract_definition,
                raw_data_file_name=raw_data_file_name,
            )

        with NamedTemporaryFile(dir=TEMP_DIR) as archive_file:
            # add file to zip archive
            with zipfile.ZipFile(
                archive_file.name, "w", zipfile.ZIP_DEFLATED
            ) as zf:
                zf.write(
                    str(extract_filename),
                    extract_definition.uncompressed_filename,
                )
            # copy archive to where it needs to go; read fixed-size
            # chunks because a binary file has no meaningful "lines"
            with smart_open.open(archive_file.name, "rb") as in_file:
                with smart_open.open(
                    extract_definition.filename,
                    "wb",
                    transport_params={"client": client},
                ) as out_file:
                    for chunk in iter(
                        partial(in_file.read, copy_chunk_size), b""
                    ):
                        out_file.write(chunk)
    return row_count

def rar_extract(
    meta: ExtractMeta,
    extract_definition: ExtractDefinition,
    raw_data_file_name: str,
    client: BaseClient,
):
    """Write the extract locally, pack it with ``rar`` and upload it.

    :param meta: driver and cursor description of the executed query
    :param extract_definition: definition of the extract to create;
        ``uncompressed_filename`` is the name of the file that is archived
    :param raw_data_file_name: path of the raw CSV dump
    :param client: boto3 S3 client handed to smart_open
    :return: number of rows written
    :raises RuntimeError: if the ``rar`` command exits with a non-zero code
    """
    _log_extract_methodology(extract_definition=extract_definition)
    copy_chunk_size = 1024 * 1024
    with TemporaryDirectory(dir=TEMP_DIR) as tempdir:
        extract_file = Path(tempdir) / extract_definition.uncompressed_filename
        archive_file = Path(tempdir) / Path(extract_definition.filename).name
        logging.info(f"extract_file: {extract_file}")
        logging.info(f"archive_file: {archive_file}")
        # create the extract file; newline="" as recommended by the csv
        # module so the configured eol is written untranslated
        with open(
            extract_file,
            "w",
            newline="",
            encoding=extract_definition.encoding,
            errors=extract_definition.encoding_errors,
        ) as file_handle:
            row_count = extract_to_file(
                file_handle=file_handle,
                meta=meta,
                extract_definition=extract_definition,
                raw_data_file_name=raw_data_file_name,
            )
        # add the file to the rar archive
        args = [
            "rar",
            "a",
            "-ep",
            "-dh",
            archive_file.as_posix(),
            extract_file.as_posix(),
        ]
        res, output = execute_shell_command(args=args)
        if res != 0:
            s = "\n".join(output)
            c = " ".join(args)
            raise RuntimeError(f"Error: {res} returned by '{c}' :\n{s}")
        # copy archive to where it needs to go; read fixed-size chunks
        # because a binary file has no meaningful "lines"
        with smart_open.open(archive_file.as_posix(), "rb") as in_file:
            with smart_open.open(
                extract_definition.filename,
                "wb",
                transport_params={"client": client},
            ) as out_file:
                for chunk in iter(
                    partial(in_file.read, copy_chunk_size), b""
                ):
                    out_file.write(chunk)
    return row_count

This code generates DAGs that extract data based on input passed in a JSON file. I need to amend it so that if two customers receive the same file ("same" meaning everything is identical: file separator, filename, encoding, etc.), the file is not regenerated; instead, the first generated file is simply copied. I have done that part, but I also need to handle the .zip, .rar and .gz extensions, and I have not been able to figure out how. Any help with this would be greatly appreciated.
Thanks in advance.