#!/usr/bin/python3 -tt

import sys
import glob
import os
import jinja2
import time

def _load_config(path, args):
    """Execute the config file at *path* and return its ``config_opts`` dict.

    The config file is plain Python that is expected to populate (or
    rebind) a variable named ``config_opts``.
    """
    # SECURITY: the config file is executed as arbitrary Python code with
    # this module's globals visible.  Only ever point this at trusted files.
    config_opts = {}
    # Same names the config could see before: the argument list and the
    # (initially empty) options dict.
    scope = {'args': args, 'config_opts': config_opts}
    with open(path, 'rb') as fh:
        code = compile(fh.read(), path, 'exec')
    exec(code, globals(), scope)
    # Read the name back from the scope so that a config which rebinds
    # ``config_opts = {...}`` (instead of mutating it) is honoured too.
    return scope.get('config_opts', config_opts)


def _apply_toggle(disable, action, targets):
    """Apply one 'enable'/'disable' action for *targets* to *disable* in place."""
    if action == 'enable':
        for target in targets:
            if target in disable:
                disable.remove(target)
    elif action == 'disable':
        for target in targets:
            if target not in disable:
                disable.append(target)


def _read_disabled(cmds_file):
    """Replay the commands recorded in *cmds_file* and return the disabled list."""
    disable = []
    if not os.path.exists(cmds_file):
        return disable
    with open(cmds_file, 'r') as fh:
        for line in fh:
            words = line.split()
            if not words:
                continue
            _apply_toggle(disable, words[0], words[1:])
    return disable


def _record_command(cmds_file, words):
    """Append the command *words* to *cmds_file* (best effort, never raises OSError)."""
    if not cmds_file:
        # No commands file configured: nothing to persist.
        return
    try:
        parent = os.path.dirname(cmds_file)
        # A bare filename has an empty dirname; os.makedirs('') would raise.
        if parent and not os.path.exists(parent):
            os.makedirs(parent)
        with open(cmds_file, 'a') as fh:
            fh.write(' '.join(words) + '\n')
    except OSError:
        print("Failed to record command in: %s" % cmds_file, file=sys.stderr)


def _collect_records(config_opts, disable):
    """Build ``{region: {rectype: [record lines]}}`` from the proxy definitions.

    Entries whose IP or any of whose names appear in *disable* are left out,
    but their region/record-type slots are still created.
    """
    by_region = {}
    ttldict = config_opts.get('ttldict', {})
    default_ttl = config_opts.get('def_proxy_ttl', '')
    for rectype, ipdict in (('A', 'ipv4_proxy'), ('AAAA', 'ipv6_proxy')):
        proxy_label = ipdict.replace('_', '-')
        for rec in config_opts.get(rectype, []):
            # A record-specific TTL if one is configured, otherwise the default
            # (which may be blank).
            ttl = ttldict.get(rec, default_ttl)
            for ip, options in config_opts.get(ipdict, {}).items():
                # An entry is either a plain list of regions or a dict with
                # 'regions' and 'names'.
                if isinstance(options, list):
                    regions, names = options, []
                else:
                    regions, names = options['regions'], options['names']
                name_disabled = bool(names) and any(
                    name in disable for name in names)
                enabled = ip not in disable and not name_disabled
                msg = '%s    %s    IN        %s         %s' % (rec, ttl, rectype, ip)
                msg2 = '%s.%s    %s    IN        %s         %s' % (
                    proxy_label, rec, ttl, rectype, ip)
                for region in regions:
                    # Always register the region, even when the entry is
                    # disabled, so that its zone file is still rewritten.
                    records = by_region.setdefault(region, {}).setdefault(rectype, [])
                    if not enabled:
                        continue
                    records.append(msg)
                    if rec != '@':
                        records.append(msg2)
    return by_region


def main(args):
    """Generate zone files from a config file and a jinja2 zone template.

    args[0] is the path of the config file (Python code that fills in
    ``config_opts``).  args[1], if present, is one of the commands
    ``reset``, ``disable`` or ``enable``; args[2:] are the IPs/names the
    ``disable``/``enable`` command applies to.

    Config keys read here: ``commands_file``, ``ttldict``, ``def_proxy_ttl``,
    ``A``, ``AAAA``, ``ipv4_proxy``, ``ipv6_proxy``, ``zone_template``,
    ``destdir`` and ``domain_name``.

    Returns 0 on success.  Calls ``sys.exit(1)`` for an unknown command or
    when the zone template does not exist.
    """
    config_opts = _load_config(args[0], args)

    # Rebuild the disabled list from previously recorded commands.
    cmds_file = config_opts.get('commands_file', '')
    disable = _read_disabled(cmds_file)

    # Apply the command given on this invocation, if any:
    #   reset          -> forget everything and blank the commands file
    #   disable/enable -> update the list and record the command
    if len(args) > 1:
        basecmd = args[1]
        if basecmd == 'reset':
            disable = []
            if os.path.exists(cmds_file):
                try:
                    # Opening for write truncates the file.
                    with open(cmds_file, 'w'):
                        pass
                except OSError:
                    print("Failed to blank: %s" % cmds_file, file=sys.stderr)
        elif basecmd in ('disable', 'enable'):
            _apply_toggle(disable, basecmd, args[2:])
            _record_command(cmds_file, args[1:])
        else:
            print("No such command: %s" % basecmd)
            sys.exit(1)

    # Warn when recorded commands are in effect.
    if os.path.exists(cmds_file) and os.stat(cmds_file).st_size > 0:
        print("Non-default commands in place:")
        with open(cmds_file) as fh:
            print(fh.read())

    # The SOA serial is seconds since the epoch, floored to the minute.
    # The added offset keeps it larger than an older YYYYMMDDII style serial.
    now = int(time.time())
    serial = (now - now % 60) + 1000000000
    config_opts['serial'] = serial

    by_region = _collect_records(config_opts, disable)

    zone_templ = config_opts.get('zone_template', '')
    if not os.path.exists(zone_templ):
        print("Template %s doesn't exist - aborting" % zone_templ, file=sys.stderr)
        sys.exit(1)
    with open(zone_templ, 'r') as fh:
        template = jinja2.Template(fh.read())
    dest = config_opts.get('destdir', os.getcwd())
    domain = config_opts.get('domain_name', '')

    # One zone file per region: <destdir>/<region>/<domain_name>
    for reg, records in by_region.items():
        destdir = dest + '/' + reg + '/'
        os.makedirs(destdir, exist_ok=True)
        # A region may carry only one record type, so default both to empty.
        a_records = '\n'.join(records.get('A', []))
        aaaa_records = '\n'.join(records.get('AAAA', []))
        with open(destdir + domain, 'w') as fo:
            fo.write(template.render(serial=serial, a_records=a_records,
                                     aaaa_records=aaaa_records, region=reg))

    # Without any regions, render it as a normal template with the whole
    # config as its context.
    if not by_region:
        # os.path.join supplies the separator that plain concatenation lacked.
        with open(os.path.join(dest, domain), 'w') as fo:
            fo.write(template.render(region='none', **config_opts))
    return 0


if __name__ == '__main__':
    # Script entry point: the first argument must be an existing config file;
    # everything after the program name is handed to main().
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Must specify a config file", file=sys.stderr)
        sys.exit(1)
    config_path = cli_args[0]
    if not os.path.exists(config_path):
        print("Could not find config %s" % config_path, file=sys.stderr)
        sys.exit(1)
    sys.exit(main(cli_args))

