Rayleigh correction of the TOA measurements from PACE instruments#

Authors: Kamal Aryal (UMBC), Pengwang Zhai (UMBC)

Summary#

This notebook performs Rayleigh correction of top-of-atmosphere (TOA) reflectances from PACE instruments using a trained neural network (NN) model.

  • The NN model is adapted from the atmospheric correction module of the FastMAPOL/component retrieval algorithm (Aryal et al., 2024). The Rayleigh signal is obtained by setting the input aerosol and surface parameters to very small values.

  • Inputs: Viewing geometry and atmospheric parameters (Surface pressure and Ozone).

  • Neural Network Output: Rayleigh reflectance at 13 discrete wavelengths (385–870 nm).

  • Interpolation: Rayleigh reflectance is interpolated to a fine wavelength grid using the physical relation R(λ) = c / λ⁴.

  • Scaling factor c: computed for each pixel by a least-squares fit to the neural network output.
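Because of the λ⁻⁴ dependence, the Rayleigh signal changes steeply across the NN wavelength range. A short sketch, assuming the pure R(λ) = c / λ⁴ law stated above (the variable names are illustrative), of how much larger it is at 440 nm than at 870 nm:

```python
import numpy as np

# NN output wavelengths (nm), as listed in this notebook
wl_demo = np.array([385, 400, 410, 440, 470, 490, 510, 530, 550, 620, 670, 740, 870], dtype=float)

# Relative Rayleigh reflectance under R(λ) = c / λ⁴, normalized to 870 nm (c cancels)
relative = (wl_demo[-1] / wl_demo) ** 4

ratio_440_870 = relative[wl_demo == 440][0]
print(f"R(440 nm) / R(870 nm) ≈ {ratio_440_870:.1f}")  # R(440 nm) / R(870 nm) ≈ 15.3
```

Under this assumption the Rayleigh reflectance at 440 nm is roughly fifteen times its 870 nm value, which is why the correction matters most in the blue.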

Rayleigh correction is an essential step in atmospheric correction and ocean color remote sensing, because molecular (Rayleigh) scattering contributes a large fraction of the TOA signal, especially at blue wavelengths. The NN model was originally developed in the FastMAPOL/component retrieval algorithm (Aryal et al., 2024) for atmospheric correction of multi-angle intensity measurements; here it is used to predict Rayleigh reflectance at 13 discrete wavelengths.

Learning Objectives#

By the end of this notebook you will be familiar with:

  • The NN training process for Rayleigh correction.

  • How to use the trained model to perform Rayleigh correction of PACE instrument measurements.

  • The importance of Rayleigh correction in ocean color remote sensing.

1. Setup#

Import all of the packages used in this notebook.

import sys
import pickle
import types

import numpy as np
import pandas as pd
import xarray as xr
import torch
import earthaccess
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from scipy import ndimage
from scipy.signal import convolve2d

# Shared directory holding the pickled NN model and the 5 km L1C/ANC files
sharedpath = '/home/jovyan/shared-public/pace-hackweek/rayleigh-correction/'

# Use the GPU if one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")
Running on the CPU

2. Neural network model class used during training#

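  • Here, we recreate the classes and functions that were used during neural network training. The model is stored as a pickled nn_model object, and unpickling it requires these definitions to be importable from the module paths of the original code (fastmapol.train, fastmapol.net, fastmapol.tool).

  • To keep the workflow lightweight, only the modules essential for the model's forward computation are recreated.

  • This avoids having to upload the entire retrieval algorithm of which the neural network was part.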

class Net3L(torch.nn.Module):
    def __init__(self, n_feature, n_hidden1,n_hidden2, n_hidden3, n_output):
        super(Net3L, self).__init__()
        self.hidden1 = torch.nn.Linear(n_feature, n_hidden1)  
        self.hidden2 = torch.nn.Linear(n_hidden1, n_hidden2)   # hidden layer
        self.hidden3 = torch.nn.Linear(n_hidden2, n_hidden3)
        self.predict = torch.nn.Linear(n_hidden3, n_output)   # output layer
    def forward(self, x):
        x = torch.nn.LeakyReLU()(self.hidden1(x))     # activation function for hidden layers
        x = torch.nn.LeakyReLU()(self.hidden2(x))
        x = torch.nn.LeakyReLU()(self.hidden3(x))
        x = self.predict(x)
        return x

fn = types.SimpleNamespace(Net3L=Net3L)

def normalize(x,xmin1,xmax1,xmean1, xstd1, option=1):
    if(option==1):
        x=(x-xmin1)/(xmax1-xmin1)
    elif(option==4):
        x=x       
    return x

def inv_normalize(x, xmin1, xmax1, xmean1, xstd1, option=1):
    if(option==1):
        x=x*(xmax1-xmin1)+xmin1

    elif(option==4):
        x=x
    return x
ftool = types.SimpleNamespace(
    normalize=normalize,
    inv_normalize=inv_normalize
)

class nn_model():
    def __init__(self, info, nv, act,
                 learning_rate, batch_size, epochs,
                 nn_state_dict, train1, test_loss1, test_loss2):
        
        self.info = info
        self.nv = nv
        self.activation = act
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.epochs = epochs

        self.test_loss1 = test_loss1
        self.test_loss2 = test_loss2

        self.xlabelv = train1.xv_df.keys()
        self.ylabelv = train1.yv_df.keys()

        self.glintmask = train1.glintmask
        self.name = train1.name
        self.xv_normalize_option = train1.xv_normalize_option
        self.yv_normalize_option = train1.yv_normalize_option

        self.xmin = train1.xmin
        self.xmax = train1.xmax
        self.ymin = train1.ymin
        self.ymax = train1.ymax

        self.xmean = train1.xmean
        self.xstd = train1.xstd
        self.ymean = train1.ymean
        self.ystd = train1.ystd

        nx = len(self.xlabelv)
        ny = len(self.ylabelv)

        self.nn = fn.Net3L(nx, nv[0], nv[1], nv[2], ny)
        self.nn.load_state_dict(nn_state_dict())

        self.xv_range_df = train1.xv_range_df
        self.angle_labelv = ['zen', 'az', 'solzen']
        if self.angle_labelv[0] in self.xv_range_df.keys():
            self.angle_range_df = self.xv_range_df[self.angle_labelv]
            self.coeff_range_df = self.xv_range_df.drop(columns=self.angle_labelv)
        else:
            print("No angles in training data")

    def forward(self, device, xpv):
        xpv1 = ftool.normalize(xpv, self.xmin.values, self.xmax.values,
                               self.xmean.values, self.xstd.values, self.xv_normalize_option)
        output = self.nn(torch.Tensor(xpv1).to(device)).cpu().data.numpy()
        return ftool.inv_normalize(output,
                                   self.ymin.values, self.ymax.values,
                                   self.ymean.values, self.ystd.values, self.yv_normalize_option)
# Recreate the modules of the original FastMAPOL/component algorithm so that
# pickle can resolve the class paths stored in the model file
sys.modules['fastmapol'] = types.ModuleType('fastmapol')
sys.modules['fastmapol.train'] = types.ModuleType('fastmapol.train')
sys.modules['fastmapol.net'] = types.ModuleType('fastmapol.net')
sys.modules['fastmapol.tool'] = types.ModuleType('fastmapol.tool')

sys.modules['fastmapol.train'].nn_model = nn_model
sys.modules['fastmapol.net'].Net3L = Net3L
sys.modules['fastmapol.tool'].normalize = normalize
sys.modules['fastmapol.tool'].inv_normalize = inv_normalize
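The module registration above works because pickle stores a class by its dotted path (presumably fastmapol.train.nn_model for this model) together with the instance's attributes; on loading it looks that path up in sys.modules and restores the attributes without calling __init__. A minimal self-contained sketch of that mechanism, using a hypothetical module path demo_pkg.model in place of fastmapol.*:

```python
import pickle
import sys
import types


def register_demo_module(cls):
    """Register a stub package 'demo_pkg.model' exposing cls (hypothetical path)."""
    sub = types.ModuleType("demo_pkg.model")
    setattr(sub, cls.__name__, cls)
    sys.modules["demo_pkg"] = types.ModuleType("demo_pkg")
    sys.modules["demo_pkg.model"] = sub


class TinyModel:
    __module__ = "demo_pkg.model"   # pretend the class lives in demo_pkg.model

    def __init__(self, scale):
        self.scale = scale


register_demo_module(TinyModel)
blob = pickle.dumps(TinyModel(2.0))   # stores "demo_pkg.model" / "TinyModel" + {'scale': 2.0}

# Simulate a fresh session: a lightweight stand-in class registered at the same path
init_calls = []


class TinyModel:  # redefinition is intentional
    __module__ = "demo_pkg.model"

    def __init__(self, scale):
        init_calls.append(scale)      # not executed during unpickling

    def forward(self, x):
        return self.scale * x


register_demo_module(TinyModel)
restored = pickle.loads(blob)
print(restored.forward(3.0))  # 6.0
print(init_calls)             # []
```

This is also why the nn_model.__init__ defined above only needs to exist, not to run: the trained attributes, including the network weights, come from the pickle file itself.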

3. Functions to load L1C and ancillary data#

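  • Here, we define functions to load the L1C and ancillary data.

  • The L1C intensities are converted to TOA reflectance, ρ = πI / (F0 cos θs), and the viewing geometry is converted to the convention used in NN training (relative azimuth folded into [0°, 180°]).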

def anc_data_reader(file):
    ds1 = xr.open_dataset(file)
    lat = ds1['latitude'].values
    lon = ds1['longitude'].values
    o3 = ds1['TO3'].values/345.23947 #divided by standard US atm, as in training of FastMAPOL/component's neural network
    rh = ds1['RH'].values[0,:,:]
    ps = ds1['SLP'].values/100
    return fill_nearest(o3), fill_nearest(ps)

def l1c_data_reader(file):    
    ds1 = xr.open_dataset(file, group='geolocation_data')
    ds1s = xr.open_dataset(file, group='sensor_views_bands')
    ds1o = xr.open_dataset(file, group='observation_data',mask_and_scale=True)
    
    lat = ds1['latitude'].values
    lon = ds1['longitude'].values
    
    wavelengths = ds1s['intensity_wavelength'].values[0]
    intensity = ds1o['i'].values[:,:,:,:]
    f0 = ds1s['intensity_f0'].values[0]
    
    solzen = ds1['solar_zenith_angle'].values[:,:,:]
    solzen1=np.reshape(solzen,(*solzen.shape,1))* np.ones((1, 1,1, intensity.shape[3]))
    ref=np.pi*intensity/(np.cos(solzen1*np.pi/180.0)*f0)
    
    zen = ds1['sensor_zenith_angle'].values[:,:,:]
    az0 = ds1['sensor_azimuth_angle'].values[:,:,:]
    solaz = ds1['solar_azimuth_angle'].values [:,:,:]   
    az=set_az(set_az0(az0, solaz, flag_bin2sun=True)) 

    idx_wv1 = np.argmin(np.abs(wavelengths - 360))
    idx_wv2 = np.argmin(np.abs(wavelengths - 871))   
    return lat,lon,solzen,ref[:,:,:,idx_wv1:idx_wv2],wavelengths[idx_wv1:idx_wv2],zen,az

def l1c_data_reader1(file):
    # Variant used here for SPEXone: open all groups as a DataTree
    # and merge them into a single Dataset
    datatree = xr.open_datatree(file)
    ds1 = xr.merge(datatree.to_dict().values())
    lat = ds1['latitude'].values
    lon = ds1['longitude'].values
    
    wavelengths = ds1['intensity_wavelength'].values[0]
    intensity = ds1['i'].values[:,:,:,:]
    f0 = ds1['intensity_f0'].values[0]
    
    solzen = ds1['solar_zenith_angle'].values[:,:,:]
    solzen1=np.reshape(solzen,(*solzen.shape,1))* np.ones((1, 1,1, intensity.shape[3]))
    ref=np.pi*intensity/(np.cos(solzen1*np.pi/180.0)*f0)
    
    zen = ds1['sensor_zenith_angle'].values[:,:,:]
    az0 = ds1['sensor_azimuth_angle'].values[:,:,:]
    solaz = ds1['solar_azimuth_angle'].values [:,:,:]   
    az=set_az(set_az0(az0, solaz, flag_bin2sun=True)) 

    idx_wv1 = np.argmin(np.abs(wavelengths - 360))
    idx_wv2 = np.argmin(np.abs(wavelengths - 871)) 
    
    return lat,lon,solzen,ref[:,:,:,idx_wv1:idx_wv2],wavelengths[idx_wv1:idx_wv2],zen,az


def fill_nearest(arr):
    # Fill missing (NaN) pixels in the ancillary data with the value of the nearest valid pixel.
    nan_mask = np.isnan(arr)
    arr_filled = np.copy(arr)
    arr_filled[nan_mask] = 0 

    # Perform nearest neighbor interpolation to fill NaNs
    nearest_idx = ndimage.distance_transform_edt(nan_mask, return_distances=False, return_indices=True)
    filled_arr = arr_filled[tuple(nearest_idx)]

    return filled_arr


def set_az0(az, solaz, flag_bin2sun=True):
    """Compute azimuth angle for the ray
    Note:
    bin2sun, used in pace l1c convention
    az0: in the range of [0,360]
    az: in the range of [0,180]
    """
    #
    if(flag_bin2sun):
        #bin2sun, used in pace l1c convention
        tmp1=az-solaz+180.0
    else:
        #sun2bin direction
        tmp1=az-solaz

    tmp1[tmp1>360]=tmp1[tmp1>360]-360
    tmp1[tmp1<0]=-tmp1[tmp1<0]
    return tmp1

def set_az(az0):
    """
    compute the az which will be used in NN
    ensure that phi is between 0-180
    """
    az = az0.copy()
    az[az<0.0] += 360.0
    az[az>180.0] = 360.0 - az[az>180.0]
    return az
    


def rgb_image(ax,ref,wavelengths):
    def find_closest(wavelength_array, target_nm):
        return np.argmin(np.abs(wavelength_array - target_nm))
    
    idx_blue = find_closest(wavelengths, 440)
    idx_green = find_closest(wavelengths, 550)
    idx_red = find_closest(wavelengths, 670)    

    R = ref[:, :,idx_red]
    G = ref[:, :,idx_green]
    B = ref[ :, :,idx_blue]
            
    def color_scaling(R, G, B):
        # Stack
        rgb = np.stack([R, G, B], axis=-1)
        # Normalize

        rgb = np.clip(rgb / 0.4, 0, 1)  
        # Gamma scaling
        rgb = rgb** (1 /2.2)#gamma  scaling
        return rgb
       
    rgb = color_scaling(R, G, B) 

    # Note: relies on the global lat/lon arrays set by the most recent L1C reader call
    ax.set_extent([lon.min(), lon.max(), lat.min(), lat.max()], crs=ccrs.PlateCarree())
    ax.coastlines(resolution='10m', linewidth=0.5)
    ax.gridlines(draw_labels=True, color='gray', linewidth=0.3)

    ax.pcolormesh(lon, lat, rgb,
              transform=ccrs.PlateCarree(), shading='auto')
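The two conversions in the readers above can be written compactly. A sketch, assuming angles in degrees and sensor/solar azimuths in [0°, 360°]: toa_reflectance mirrors the ref = π·I/(cos θs·F0) line, and relative_azimuth is intended as an equivalent of set_az(set_az0(az, solaz, flag_bin2sun=True)) that uses a modulo instead of the masked updates. Both helper names are hypothetical.

```python
import numpy as np


def toa_reflectance(radiance, f0, solzen_deg):
    """rho = pi * I / (F0 * cos(theta_s)), the conversion used in l1c_data_reader."""
    return np.pi * radiance / (f0 * np.cos(np.deg2rad(solzen_deg)))


def relative_azimuth(sensor_az, solar_az):
    """Relative azimuth in [0, 180] deg for the bin-to-sun (PACE L1C) convention."""
    phi = np.mod(sensor_az - solar_az + 180.0, 360.0)   # wrap into [0, 360)
    return np.where(phi > 180.0, 360.0 - phi, phi)       # fold into [0, 180]


print(relative_azimuth(np.array([100.0, 10.0, 350.0, 100.0]),
                       np.array([280.0, 300.0, 20.0, 0.0])))
# -> 0, 110, 150, 80 (degrees)
```

Folding into [0°, 180°] relies on the symmetry of the radiance field about the principal plane, which is the convention the set_az docstring describes for the NN.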

Function to load the NN model#

def load_nn(device, sharedpath):
    nn_path = sharedpath + 'RayleighNN_FastMAPOL_component.pk'
    # Unpickling requires the fastmapol.* modules recreated above
    with open(nn_path, 'rb') as f:
        nn = pickle.load(f)
    nn.nn.to(device)
    return nn

4. Understanding the trained model#

nn=load_nn(device,sharedpath)

print(nn.activation)
print(nn.name)
print(nn.nv)
print(nn.xlabelv)
LeakyReLU
refas
[600, 300, 150]
Index(['zen', 'az', 'solzen', 'wndspd', 'aod', 'alh', 'fmf', 'ss', 'fnai',
       'bc', 'brc', 'rh', 'o3', 'ps'],
      dtype='object')
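The printed configuration describes a fully connected network: 14 inputs (xlabelv), three LeakyReLU hidden layers of 600, 300 and 150 units (nv), and one output per NN wavelength. A sketch that rebuilds this layout with random weights to check shapes and count parameters, assuming 13 outputs (one per wavelength in wl_nn; len(nn.ylabelv) would confirm this for the real model). The class is renamed Net3LSketch so it does not shadow the notebook's Net3L.

```python
import torch
import torch.nn.functional as F


class Net3LSketch(torch.nn.Module):
    """Same layer layout as the notebook's Net3L: three hidden layers with LeakyReLU."""

    def __init__(self, n_feature, n_hidden1, n_hidden2, n_hidden3, n_output):
        super().__init__()
        self.hidden1 = torch.nn.Linear(n_feature, n_hidden1)
        self.hidden2 = torch.nn.Linear(n_hidden1, n_hidden2)
        self.hidden3 = torch.nn.Linear(n_hidden2, n_hidden3)
        self.predict = torch.nn.Linear(n_hidden3, n_output)

    def forward(self, x):
        x = F.leaky_relu(self.hidden1(x))   # default slope 0.01, same as torch.nn.LeakyReLU()
        x = F.leaky_relu(self.hidden2(x))
        x = F.leaky_relu(self.hidden3(x))
        return self.predict(x)


# 14 inputs, hidden sizes from nv, 13 outputs (assumed)
net = Net3LSketch(14, 600, 300, 150, 13)
n_params = sum(p.numel() for p in net.parameters())
print(n_params)  # 236413

with torch.no_grad():
    y = net(torch.rand(5, 14))   # batch of 5 random, already-normalized input vectors
print(tuple(y.shape))  # (5, 13)
```

The count is (14·600 + 600) + (600·300 + 300) + (300·150 + 150) + (150·13 + 13) = 236,413 weights and biases, so a forward pass over a full scene is cheap compared with a radiative transfer calculation.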

5. Functions to create the input vector for the neural network and get Rayleigh reflectance at NN wavelengths#

## Set the aerosol and surface parameters (wind speed, AOD, ALH, FMF, SS, FNAI, BC, BrC and RH) to a very small value (0.001)

def nn_input_vector(zen, az, solzen, o3, ps):
    H, W = ps.shape
    inputs = np.zeros((H, W, 14), dtype=np.float32)
    inputs[..., 0] = zen
    inputs[..., 1] = az
    inputs[..., 2] = solzen
    inputs[..., 3:12] = 0.001  # wndspd, aod, alh, fmf, ss, fnai, bc, brc, rh: near-zero aerosol/surface inputs
    inputs[..., 12] = o3
    inputs[..., 13] = ps
    return inputs
def refl_nn(device, inputs, nn):    
    ref=nn.forward(device, inputs)
    return ref
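nn_input_vector above fills the 14 input columns by position, so it silently depends on the column order of nn.xlabelv. A hypothetical alternative that places each field by its label instead (the helper name nn_input_by_label is illustrative, not part of FastMAPOL):

```python
import numpy as np

# Input order printed by nn.xlabelv above
XLABELS = ['zen', 'az', 'solzen', 'wndspd', 'aod', 'alh', 'fmf', 'ss', 'fnai',
           'bc', 'brc', 'rh', 'o3', 'ps']


def nn_input_by_label(xlabels, fields, fill=0.001):
    """Build an (H, W, len(xlabels)) input array.

    fields maps input labels to (H, W) arrays; labels not present in fields
    (the aerosol and surface terms) receive the small constant `fill`.
    """
    shape = np.shape(next(iter(fields.values())))
    inputs = np.full((*shape, len(xlabels)), fill, dtype=np.float32)
    for i, label in enumerate(xlabels):
        if label in fields:
            inputs[..., i] = fields[label]
    return inputs


H, W = 2, 3
fields = {'zen': np.full((H, W), 30.0), 'az': np.full((H, W), 90.0),
          'solzen': np.full((H, W), 40.0), 'o3': np.ones((H, W)),
          'ps': np.full((H, W), 1013.25)}
x = nn_input_by_label(XLABELS, fields)
print(x.shape)  # (2, 3, 14)
```

With the real model, one would pass list(nn.xlabelv) and a dictionary holding zen, az and solzen for the chosen view plus o3 and ps; the result matches nn_input_vector as long as the label order is the one printed above.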

6. Function to interpolate Rayleigh reflectance from NN wavelengths to a finer spectral grid#

This function interpolates Rayleigh reflectance predicted at 13 discrete wavelengths (using a neural network) to a finer spectral resolution using the physical Rayleigh scattering relationship:

R(λ) = c / λ⁴

To estimate the scale factor c per pixel, we use a least-squares fit over the 13 known reflectances. The optimal c minimizes the squared error between the predicted reflectance and the model:

c = Σ(R_i × (1/λ_i⁴)) / Σ((1/λ_i⁴)²)

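Where:

  • R_i is the NN-predicted Rayleigh reflectance at wavelength λ_i

  • The expression follows from minimizing S(c) = Σ(R_i − c/λ_i⁴)²: setting dS/dc = 0 gives Σ(R_i × (1/λ_i⁴)) = c Σ((1/λ_i⁴)²), which rearranges to the formula above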
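The closed-form c above is the ordinary least-squares slope of a regression through the origin on x_i = 1/λ_i⁴. A quick numerical check on synthetic values (the scale c_true and the 1% noise are illustrative assumptions, not NN output), comparing it with NumPy's generic least-squares solver:

```python
import numpy as np

wl_demo = np.array([385, 400, 410, 440, 470, 490, 510, 530, 550, 620, 670, 740, 870], dtype=np.float64)
x = 1.0 / wl_demo**4                 # regressor 1/λ⁴

rng = np.random.default_rng(0)
c_true = 3.0e9                       # illustrative scale: R(440 nm) is about 0.08
R = c_true * x * (1 + 0.01 * rng.standard_normal(wl_demo.size))   # synthetic reflectances, 1% noise

# Closed form used in rayleigh_ref_interp
c_closed = np.sum(R * x) / np.sum(x**2)

# Generic solver on the one-column design matrix
c_lstsq = np.linalg.lstsq(x[:, None], R, rcond=None)[0][0]

print(c_closed, c_lstsq, c_closed / c_true)
```

Both estimates agree to rounding error. Because the weights scale as 1/λ⁸, the short-wavelength NN outputs dominate the fit.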

wl_nn = np.array([385, 400, 410, 440, 470, 490, 510, 530, 550, 620, 670, 740, 870], dtype=np.float32)
def rayleigh_ref_interp(ref_nn, wl_fine):
    """
    ref_nn:  (H, W, 13) NN-predicted Rayleigh reflectance at the wl_nn wavelengths
    wl_fine: (N,) target wavelengths (nm) to interpolate to
    returns: (H, W, N) interpolated Rayleigh reflectance
    """
    rayleigh_nn = 1.0 / (wl_nn ** 4)       # (13,), uses the module-level wl_nn
    rayleigh_fine = 1.0 / (wl_fine ** 4)   # (N,)

    # Least-squares fit per pixel: c = sum(R_i * (1/λ_i⁴)) / sum((1/λ_i⁴)^2)
    numerator = np.sum(ref_nn * rayleigh_nn, axis=2)    # (H, W)
    denominator = np.sum(rayleigh_nn ** 2)               # scalar
    c = numerator / denominator                          # (H, W)

    return c[:, :, np.newaxis] * rayleigh_fine           # (H, W, N)

7. Rayleigh correction for OCI/SPEXone/HARP2 measurements#

Load L1C and ancillary data#

NASA Earthdata login#

auth = earthaccess.login(persist=True)
OCI = earthaccess.search_data(
    short_name="PACE_OCI_L1C_SCI",
    temporal=("2024-07-20T14:04:30","2024-07-20T14:04:31"),
#    cloud_cover=clouds,
    count=1
)
Spexone = earthaccess.search_data(
    short_name="PACE_SPEXone_L1C_SCI",
    temporal=("2024-07-20T14:04:30","2024-07-20T14:04:31"),
#    cloud_cover=clouds,
    count=1
)
# Harp2 = earthaccess.search_data(
#     short_name="PACE_HARP2_L1C_SCI",
#     temporal=("2024-07-20T14:04:30","2024-07-20T14:04:31"),
# #    cloud_cover=clouds,
#     count=1
# )
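# File-like objects streamed from Earthdata. Note that the cells below
# read the 5 km gridded files from sharedpath instead.
ocifile = earthaccess.open(OCI)[0]
spexfile = earthaccess.open(Spexone)[0]
#harp2file = earthaccess.open(Harp2)[0]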

For OCI#

## For OCI
fileanc=sharedpath+'PACE.20240720T140430.L1C.ANC.5km.nc'
ocifile=sharedpath+'PACE_OCI.20240720T140430.L1C.V3.5km.nc'


## For SPEXone
#fileanc='PACE.20240720T140430.L1C.ANC.5km.spex_width.nc'
#file='PACE_SPEXONE.20240720T140430.L1C.V2.5km.nc'

lat,lon,solzen,ref_oci,wl_oci,zen,az=l1c_data_reader(ocifile)
o3,ps=anc_data_reader(fileanc)

Get reflectances (total, Rayleigh, and Rayleigh-corrected)#

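def get_ref(device, sharedpath, toa_ref, wavelengths, inda):
    # Note: zen, az, solzen, o3 and ps are read from the global variables set by
    # the most recent L1C and ancillary reader calls; inda selects the view.
    inputp = nn_input_vector(zen[:, :, inda], az[:, :, inda], solzen[:, :, inda], o3, ps)
    nn = load_nn(device, sharedpath)
    ref_nn = refl_nn(device, inputp, nn)              # (H, W, 13) at wl_nn
    ray_ref = rayleigh_ref_interp(ref_nn, wavelengths)
    total_ref = toa_ref[:, :, inda, :]
    corr_ref = total_ref - ray_ref                    # Rayleigh-corrected TOA reflectance
    return total_ref, ray_ref, corr_ref, ref_nn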

Rayleigh correction for an arbitrary OCI pixel#

total_ref, ray_refoci, corr_ref, ref_nn = get_ref(device, sharedpath, ref_oci, wl_oci, inda=1)

plt.plot(wl_oci, total_ref[300, 260, :], c='b', label='Total TOA reflectance')
plt.plot(wl_nn, ref_nn[300, 260, :], linestyle='none', marker='x', color='gray', label='NN Rayleigh reflectance')
plt.plot(wl_oci, ray_refoci[300, 260, :], c='gray', label='Rayleigh reflectance (interpolated)')
plt.plot(wl_oci, corr_ref[300, 260, :], c='g', label='Rayleigh corrected TOA reflectance')
plt.xlabel('Wavelength (nm)')
plt.ylabel('Reflectance')
plt.legend()
[Figure: total TOA reflectance, NN Rayleigh reflectance (markers), interpolated Rayleigh reflectance, and Rayleigh-corrected TOA reflectance versus wavelength for OCI pixel (300, 260)]

Rayleigh correction for multiangle measurements from SPEXone#

  • SPEXone provides multi-angle measurements.

  • If a pixel is cloud-contaminated in the OCI view, other SPEXone viewing angles may be cloud-free and provide supplementary information.
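Reusing the 670 nm threshold (0.15) that this notebook later applies to OCI, a minimal sketch of flagging which views of each pixel look clear. The helper name clear_view_mask is hypothetical, and the threshold is a simple heuristic rather than a cloud mask product.

```python
import numpy as np


def clear_view_mask(ref, wavelengths, threshold=0.15, wl_test=670.0):
    """Boolean (H, W, n_views) mask: True where TOA reflectance near wl_test is below threshold.

    ref: (H, W, n_views, n_bands) reflectance, the layout returned by l1c_data_reader1.
    """
    idx = np.argmin(np.abs(wavelengths - wl_test))
    return ref[:, :, :, idx] < threshold


# Synthetic example: 2 x 2 pixels, 5 views, 3 bands
wl_demo = np.array([440.0, 550.0, 670.0])
ref_demo = np.full((2, 2, 5, 3), 0.05)
ref_demo[0, 0, :2, 2] = 0.4          # pixel (0, 0): first two views look cloudy at 670 nm
mask = clear_view_mask(ref_demo, wl_demo)
print(mask.sum(axis=2))              # number of clear views per pixel
```

Pixel (0, 0) keeps three clear views out of five, illustrating how a multi-angle instrument can still offer usable looks when some views are contaminated.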

# For SPEXone
anc_sp1 = sharedpath + 'PACE.20240720T140430.L1C.ANC.5km.spex_width.nc'
spexfile = sharedpath + 'PACE_SPEXONE.20240720T140430.L1C.V2.5km.nc'

lat, lon, solzen, ref_sp1, wl_sp1, zen, az = l1c_data_reader1(spexfile)
o3, ps = anc_data_reader(anc_sp1)
fig, ax = plt.subplots(1, 5, figsize=[30, 6])
for i in range(5):  # one panel per SPEXone view (iang)
    total_refs, ray_refs, corr_refs, ref_nn = get_ref(device, sharedpath, ref_sp1, wl_sp1, i)
    ax[i].plot(wl_sp1, total_refs[300, 16, :], c='b', label='Total TOA reflectance')
    ax[i].plot(wl_sp1, ray_refs[300, 16, :], c='gray', label='Rayleigh reflectance')
    ax[i].plot(wl_sp1, corr_refs[300, 16, :], c='g', label='Rayleigh corrected TOA reflectance')
    ax[i].set_title('iang=' + str(i), fontsize=14)
    ax[i].set_xlabel('Wavelength (nm)')
ax[0].legend(fontsize=11)
[Figure: total, Rayleigh, and Rayleigh-corrected TOA reflectance spectra for SPEXone pixel (300, 16) in each of the five views, iang = 0 to 4]

Let's compare RGB images before and after Rayleigh correction#

# Re-read the OCI data: the SPEXone cells above overwrote lat, lon, solzen, zen, az, o3 and ps
lat,lon,solzen,ref_oci,wl_oci,zen,az=l1c_data_reader(ocifile)
o3,ps=anc_data_reader(fileanc)

total_ref, ray_ref, corr_ref, ref_nn=get_ref(device,sharedpath,ref_oci,wl_oci,inda=1)

fig,ax=plt.subplots(1,2,figsize=[30,9],subplot_kw={'projection': ccrs.PlateCarree()})
rgb_image(ax[0],total_ref,wl_oci)
rgb_image(ax[1],corr_ref,wl_oci)
ax[0].set_title('Original RGB Image',fontsize=25)
ax[1].set_title('RGB Image after Rayleigh correction',fontsize=25)
[Figure: RGB composites (440, 550, 670 nm) of the OCI scene before (left) and after (right) Rayleigh correction]

RGB images after masking non-clear-sky pixels#

  • At longer wavelengths (e.g., 670 nm), the TOA reflectance of clear-sky pixels over the ocean is very low.

  • Pixels with a total TOA reflectance at 670 nm of 0.15 or more are treated as non-clear-sky and masked, so that ocean color features stand out clearly.

idx_670 = np.argmin(np.abs(wl_oci - 670))
# Extract reflectance at 670 nm
refl_670 = total_ref[:, :, idx_670]

# Define cloud threshold
clear_mask = refl_670 < 0.15

clear_total_ref=np.where(clear_mask[:, :, np.newaxis], total_ref, 0)
clear_corr_ref=np.where(clear_mask[:, :, np.newaxis], corr_ref, 0)

fig,ax=plt.subplots(1,2,figsize=[30,9],subplot_kw={'projection': ccrs.PlateCarree()})
rgb_image(ax[0],clear_total_ref,wl_oci)
rgb_image(ax[1],clear_corr_ref,wl_oci)
ax[0].set_title('Original RGB Image',fontsize=25)
ax[1].set_title('RGB Image after Rayleigh correction',fontsize=25)
[Figure: RGB composites of the OCI scene with non-clear-sky pixels masked, before (left) and after (right) Rayleigh correction]