Source code for edge_pydb.hexgrid

import numpy as np
from astropy.table import Table, Column
from edge_pydb.conversion import gc_polr
from functools import wraps

'''
Interpolate the data onto a hexagonal grid.  This code is still in a
rough state.  To be done:
(1) Support for cubes, for which there is an additional coordinate
    column iz
(2) The ra_abs, dec_abs, ra_off, and dec_off columns are still determined
    by interpolation; they should be calculated from ix and iy from the WCS
'''

# Set to True to compile the numerical kernels below with numba when it is
# installed; otherwise the pure-Python stand-ins defined further down are used.
use_numba = False

if use_numba:
    use_python = False
    try:
        from numba import jit, njit, prange
        from numba.typed import List
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so one clause
        # covers both "package missing" and "package broken".
        print("cannot import numba modules, computation will be slow")
        use_python = True
else:
    use_python = True

if use_python:
    # Pure-Python replacements for the numba names used in this module.
    prange = range
    List = list

    def py_njit(parallel=False, **options):
        """No-op stand-in for ``numba.njit``.

        Supports both ``@njit`` (bare) and ``@njit(parallel=..., **options)``.
        All options are accepted and ignored; the decorated function is
        returned unchanged.
        """
        if callable(parallel):
            # Bare ``@njit`` usage: the decorated function arrives in the
            # first positional slot instead of a boolean.
            return parallel

        def py_njit_func(func):
            return func

        return py_njit_func

    def py_jit(func=None, **options):
        """No-op stand-in for ``numba.jit``.

        Supports both ``@jit`` (bare) and ``@jit(**options)``.  All options
        are accepted and ignored; the decorated function is returned
        unchanged.
        """
        if func is None:
            # Called with options only: hand back an identity decorator.
            def py_jit_func(inner):
                return inner

            return py_jit_func
        return func

    jit = py_jit
    njit = py_njit
@jit
def ylin_hex(pos, hex_sidelen, bound):
    """Build one column of grid points through ``pos`` along the y direction.

    Parameters
    ----------
    pos : sequence of two numbers
        ``(x, y)`` seed point of the column.
    hex_sidelen : float
        Hexagon side length; consecutive points are ``sqrt(3) * hex_sidelen``
        apart in y.
    bound : indexable as ``bound[0][1]`` / ``bound[1][1]``
        ``bound[0]`` is the lower corner and ``bound[1]`` the upper corner;
        only the y components are used here.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, 2)``: every row has ``x == pos[0]`` and the y
        values step from near the lower y bound up to (but excluding) the
        upper y bound.  An empty column is returned with shape ``(0, 2)``.
    """
    dist = np.sqrt(3) * hex_sidelen
    # From the seed point (included) upward to the upper y bound (excluded).
    y_plus = np.arange(pos[1], bound[1][1], dist)
    # Whole number of steps that fit between the lower y bound and the seed.
    num = np.floor((pos[1] - bound[0][1]) / dist)
    # NOTE(review): with a non-integer step np.arange may, through rounding,
    # emit one extra value ~= pos[1] here, duplicating the first entry of
    # y_plus -- callers should be prepared to drop near-duplicates.
    y_minus = np.arange(pos[1] - num * dist, pos[1], dist)
    y_coord = np.concatenate([y_minus, y_plus])
    # column_stack keeps the (n, 2) shape even when n == 0, unlike building
    # the array from a list of zipped tuples.
    return np.column_stack((np.ones_like(y_coord) * pos[0], y_coord))
@jit
def hex_basis(ref, hex_sidelen, bound):
    """Generate seed points along the 30-degree diagonal through ``ref``.

    Starting at ``ref``, points are stepped by
    ``sqrt(3) * hex_sidelen * (cos(pi/6), sin(pi/6))`` in the positive
    direction while both coordinates stay strictly below ``bound[1]``, then
    in the negative direction (starting one step below ``ref``) while both
    coordinates stay strictly above ``bound[0]``.

    Parameters
    ----------
    ref : sequence of two numbers
        Reference point; converted to floats.
    hex_sidelen : float
        Hexagon side length; must be positive.
    bound : indexable as ``bound[k][0]`` / ``bound[k][1]``
        ``bound[0]`` is the lower corner, ``bound[1]`` the upper corner.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, 2)``; positive-direction points first (starting
        with ``ref`` when it lies below the upper corner), then the
        negative-direction points.

    Raises
    ------
    ValueError
        If ``hex_sidelen`` is not a positive number (the stepping loops
        below would otherwise never terminate).
    """
    if not hex_sidelen > 0:
        raise ValueError("hex_sidelen must be positive")

    dist = np.sqrt(3) * hex_sidelen
    angle = np.pi / 6
    delta = np.array((dist * np.cos(angle), dist * np.sin(angle)))

    basis_pos = []
    ref = np.array((float(ref[0]), float(ref[1])))

    # Walk up the diagonal until either coordinate reaches the upper corner.
    plus = np.copy(ref)
    while plus[0] < bound[1][0] and plus[1] < bound[1][1]:
        basis_pos.append(np.copy(plus))
        plus += delta

    # Walk down the diagonal; ref itself was already emitted above.
    minus = np.copy(ref) - delta
    while minus[0] > bound[0][0] and minus[1] > bound[0][1]:
        basis_pos.append(np.copy(minus))
        minus -= delta

    # reshape keeps a consistent (n, 2) shape when no point was generated.
    return np.array(basis_pos).reshape(-1, 2)
@jit
def hex_grid(ref, sidelen, bound, starting_angle, precision):
    """Generate the hexagonal grid points that fall strictly inside ``bound``.

    Seed points from ``hex_basis`` are expanded into columns with
    ``ylin_hex`` over an enlarged square, the result is rotated by
    ``starting_angle``, clipped to ``bound``, de-duplicated, optionally
    rounded, and sorted.

    Parameters
    ----------
    ref : sequence of two numbers
        Reference point handed to ``hex_basis``.
    sidelen : float
        Hexagon side length (equal to the circumcircle radius, i.e.
        ``2 / sqrt(3)`` times the incircle radius).
    bound : numpy.ndarray
        ``[[x_min, y_min], [x_max, y_max]]``; indexed both as ``bound[i][j]``
        and ``bound[i, j]``.
    starting_angle : float
        Rotation angle in radians applied to the generated grid.
    precision : int
        Number of decimals passed to ``np.around``; ``0`` disables rounding.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, 2)`` sorted by the first column, then the second.
        Shape ``(0, 2)`` when no grid point lies inside ``bound``.
    """
    x_len = bound[1][0] - bound[0][0]
    y_len = bound[1][1] - bound[0][1]
    max_len = x_len if x_len > y_len else y_len

    # Enlarged square in which the un-rotated grid is generated.
    # NOTE(review): both corners are offset from bound[0], and the rotation
    # below is about the origin rather than about ``ref``; for angles with
    # cos(starting_angle) near -1 the square collapses.  Left unchanged --
    # confirm the intended geometry before relying on non-zero angles.
    compen = max_len * (1 + np.cos(starting_angle))
    rotate_compen = np.array((compen, compen))
    max_bound = [bound[0] - rotate_compen, bound[0] + rotate_compen]

    x_basis = hex_basis(ref, sidelen, max_bound)
    # reshape guards against an empty column coming back one-dimensional.
    columns = [np.reshape(ylin_hex(point, sidelen, max_bound), (-1, 2))
               for point in x_basis]
    if not columns:
        return np.empty((0, 2))
    grid = np.concatenate(columns, axis=0)

    cos_a = np.cos(starting_angle)
    sin_a = np.sin(starting_angle)
    rotation_mtx = np.array([[cos_a, sin_a],
                             [-sin_a, cos_a]])
    grid = np.dot(grid, rotation_mtx)

    # Keep only points strictly inside the requested bounds.
    cut = (grid[:, 0] > bound[0, 0]) & (grid[:, 1] > bound[0, 1]) \
        & (grid[:, 0] < bound[1, 0]) & (grid[:, 1] < bound[1, 1])
    candidates = grid[cut]

    # Drop repeated points: a point is discarded when an earlier kept point
    # agrees with it to within 1e-5 in BOTH coordinates.  Comparing the pair
    # (rather than x and y independently) avoids losing a point that merely
    # shares one coordinate with a different duplicated point.
    keep = np.ones(len(candidates), dtype=bool)
    for i in range(1, len(candidates)):
        earlier = candidates[:i][keep[:i]]
        if (np.abs(earlier - candidates[i]) < 1e-5).all(axis=1).any():
            keep[i] = False
    unique = candidates[keep]

    if precision != 0:
        unique = np.around(unique, precision)

    # Sort by the first coordinate, ties broken by the second.
    order = np.lexsort((unique[:, 1], unique[:, 0]))
    return unique[order]
@jit
def interpolate_neighbor(point, bound, step_size=0):
    """Find the pixel positions around ``point`` to interpolate from.

    Parameters
    ----------
    point : numpy.ndarray
        ``(x, y)`` position to interpolate at.
    bound : indexable as ``bound[k][0]`` / ``bound[k][1]``
        ``bound[0]`` is the lower corner, ``bound[1]`` the upper corner;
        both edges are inclusive.
    step_size : int, optional
        ``0`` (default) uses the four floor/ceil corner combinations of
        ``point``.  A positive integer ``s`` uses every offset ``(i, j)``
        with ``-s <= i, j < s`` around ``ceil(point)``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(k, 2)`` holding the candidates that lie inside
        ``bound``, in generation order.  When ``point`` sits exactly on a
        pixel the four default candidates coincide with that pixel.  Shape
        ``(0, 2)`` when no candidate is inside ``bound``.
    """
    if step_size == 0:
        tmp = List([np.ceil(point),
                    np.floor(point),
                    np.array([np.floor(point[0]), np.ceil(point[1])]),
                    np.array([np.ceil(point[0]), np.floor(point[1])])])
    else:
        tmp = List()
        cen = np.ceil(point)
        # The inner offset must restart for every outer offset, otherwise
        # only the first row of the window is ever produced.
        for i in range(-step_size, step_size):
            for j in range(-step_size, step_size):
                tmp.append(np.array([cen[0] + i, cen[1] + j]))

    inside = []
    for k in range(len(tmp)):
        coord = tmp[k]
        if bound[0][0] <= coord[0] <= bound[1][0] \
                and bound[0][1] <= coord[1] <= bound[1][1]:
            inside.append(coord)

    if not inside:
        # No usable neighbour: return an empty, correctly shaped array
        # instead of failing on an unassigned result.
        return np.empty((0, 2))
    return np.array(inside).reshape(-1, 2)
# Note that if using numba, there might be an error caused by mixed use of
# float32 and float64: when casting float64 to float32, memory can overflow.
@njit(parallel=True)
def interpolate_all_points(tab, datapoint, bound, header):
    """Interpolate every table column onto the given sample positions.

    For each sample position the neighbouring pixels are obtained from
    ``interpolate_neighbor``.  If a neighbour coincides with the sample
    (distance 0) that pixel's row is copied; otherwise the columns are
    combined with inverse-distance weights normalised to sum to one.

    Parameters
    ----------
    tab : numpy structured array
        Must have ``'ix'`` and ``'iy'`` fields used to look pixels up.
        # assumes ix/iy are the first two fields, all fields are float32,
        # and each (ix, iy) pair matches exactly one row -- the
        # ``view(np.float32)[2:]`` step depends on it; TODO confirm.
    datapoint : numpy.ndarray
        Array of shape ``(n, 2)`` with the positions to sample at.
    bound : array-like
        Bounds forwarded to ``interpolate_neighbor``.
    header : sequence
        Column names; only its length is used (two coordinate columns
        followed by the value columns).

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``(n, len(header))``: columns 0-1 hold the
        sample position, the remaining columns the interpolated values.
    """
    n_values = len(header[2:])
    sampled_tab = np.zeros((datapoint.shape[0], len(header)), dtype=np.float32)
    for j in prange(datapoint.shape[0]):
        coord = datapoint[j]
        inter = interpolate_neighbor(coord, bound)
        weight_arr = np.zeros(inter.shape[0])
        on_original_pix = False
        idx = 0
        for i, point in enumerate(inter):
            distance = np.linalg.norm(point - coord)
            if distance == 0:
                # The sample sits exactly on a pixel: use it directly.
                idx = i
                on_original_pix = True
                break
            weight_arr[i] = 1. / distance

        if on_original_pix:
            row = tab[(tab['ix'] == inter[idx][0]) & (tab['iy'] == inter[idx][1])]
        else:
            # The same normalised weight is applied to every value column.
            total_weight = np.sum(weight_arr)
            w = np.ones((len(weight_arr), n_values))
            for i in np.arange(len(weight_arr)):
                w[i, :] *= np.float32(weight_arr[i] / total_weight)
            row = np.zeros(n_values)
            for i in np.arange(inter.shape[0]):
                # NaN in any neighbour's column propagates into that
                # column's weighted sum.
                cur = tab[(tab['ix'] == inter[i][0]) & (tab['iy'] == inter[i][1])]
                row = np.sum([row, w[i, :] * cur.view(np.float32)[2:]], axis=0)

        sampled_tab[j, :2] = coord
        if isinstance(row[0], np.void):
            # Structured record from the exact-pixel branch: unpack the
            # value fields (everything after the two coordinate fields).
            row = np.array([np.float32(row[0][i]) for i in np.arange(2, len(row[0]))])
        sampled_tab[j, 2:] = row
    return sampled_tab
# @jit
def hex_sampler(tab, sidelen, keepref, ref_pix, ra_ref, dec_ref, ra_gc, dec_gc, pa, inc,
                starting_angle=0, precision=2, hexgrid_output=None):
    '''
    Build a hexagonal grid over the table's pixel range, interpolate the
    table onto it, and return the sampled data as an Astropy Table.

    tab            -> input table; its first two columns give the pixel
                      coordinates that define the sampling bounds
    sidelen        -> hexagon side length handed to hex_grid
    keepref        -> if true the grid is anchored at ref_pix, otherwise at
                      the fixed position (80, 80)
    ra_ref, dec_ref, ra_gc, dec_gc, pa, inc
                   -> forwarded to gc_polr to recompute 'rad_arc' and
                      'azi_ang' when both columns are present
    starting_angle -> rotation angle handed to hex_grid
    precision      -> decimals for the grid positions; 0 -> full precision
    hexgrid_output -> optional output file to check the grid positions
    '''
    colnames = tab.colnames
    first, second = colnames[0], colnames[1]

    # Sampling bounds: extrema of the two leading (coordinate) columns.
    upper = np.array([float(max(tab[first])), float(max(tab[second]))])
    lower = np.array([float(min(tab[first])), float(min(tab[second]))])
    bound = np.array([lower, upper])

    cen = ref_pix if keepref else np.array([80., 80.])

    # Generate the hex grid positions and optionally dump them for checking.
    datapoint = hex_grid(cen, sidelen, bound, starting_angle, precision)
    if hexgrid_output is not None:
        np.savetxt(hexgrid_output, datapoint)

    info = interpolate_all_points(tab.as_array(), datapoint, bound, List(colnames))

    # Astropy has trouble creating a table with dimensionless columns, so an
    # empty unit is passed on as None.
    units = []
    for col in tab.colnames:
        unit = tab[col].unit
        units.append(None if unit == '' else unit)
    sampled_tab = Table(info, names=colnames, units=units)

    # The polar-coordinate columns are recomputed from the sampled offsets
    # rather than kept as interpolated values.
    has_polar = ('rad_arc' in colnames) and ('azi_ang' in colnames)
    if has_polar:
        sampled_tab['rad_arc'], sampled_tab['azi_ang'] = gc_polr(
            sampled_tab['ra_off'] + ra_ref,
            sampled_tab['dec_off'] + dec_ref,
            ra_gc, dec_gc, pa, inc,
        )
    return sampled_tab