QG3 Extension

This page documents the QG3-based SFNO layers and utility functions defined to work with the QG3 package.

ESM_PINO.SFNO - Method
SFNO(
    ggsh::QG3.GaussianGridtoSHTransform,
    shgg::QG3.SHtoGaussianGridTransform;
    in_channels,
    out_channels,
    hidden_channels,
    n_layers,
    num_encoder_layers,
    num_decoder_layers,
    lifting_channel_ratio,
    projection_channel_ratio,
    channel_mlp_expansion,
    activation,
    positional_embedding,
    inner_skip,
    outer_skip,
    operator_type,
    use_norm,
    downsampling_factor,
    modes,
    gpu,
    batch_size,
    soft_gating,
    bias
) -> SFNO{E, L, B, P, ESM_PINOQG3Ext.ESM_PINOQG3} where {E, L, B, P}

Spherical Fourier Neural Operator (SFNO) layer combining positional embeddings, spectral kernels, and channel MLPs.

This layer implements the SFNO architecture on the sphere, optionally using Zonal Symmetric Kernels (ZSK), following the approach described in "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (Bonev et al., 2023).

Arguments

  • ggsh::QG3.GaussianGridtoSHTransform: Precomputed grid-to-SH transform.
  • shgg::QG3.SHtoGaussianGridTransform: Precomputed SH-to-grid transform.
  • Other keyword arguments are the same as for the primary constructor, except modes, whose default is set to ggsh.output_size[1]. There is no need to specify batch_size or gpu, as these are handled by the transforms.

Returns

  • SFNO: A Lux-compatible container layer.

Details

  • Constructs lifting, SFNO blocks, and projection layers compatible with Lux.jl.
  • Positional embeddings are appended if positional_embedding="grid".
  • Supports both CPU and GPU execution.
  • Zonal Symmetric Kernels (ZSK) reduce the number of parameters and improve stability on spherical domains.

Example

using Lux, QG3, Random, NNlib, LuxCUDA

# Load precomputed QG3 parameters
qg3ppars = QG3.load_precomputed_params()[2]

# Input: [lat, lon, channels, batch]
x = rand(Float32, 32, 64, 3, 10)


# Construct SFNO layer using secondary constructor
ggsh = QG3.GaussianGridtoSHTransform(qg3ppars, 32, N_batch=size(x,4))
shgg = QG3.SHtoGaussianGridTransform(qg3ppars, 32, N_batch=size(x,4))
model2 = SFNO(ggsh, shgg;
    modes=15,
    in_channels=3,
    out_channels=3,
    hidden_channels=32,
    n_layers=4,
    lifting_channel_ratio=2,
    projection_channel_ratio=2,
    channel_mlp_expansion=2.0,
    positional_embedding="no_grid",
    outer_skip=true,
    zsk=true
)

# Setup parameters and state
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model2)

# Forward pass
y, st = model2(x, ps, st)

# Compute gradients
using Zygote
gr = Zygote.gradient(ps -> sum(model2(x, ps, st)[1]), ps)
ESM_PINO.SFNO - Method

Modified SFNO constructor with variable encoder/decoder depths.

New Arguments

  • num_encoder_layers::Int=2: Number of layers in the encoder (lifting)
  • num_decoder_layers::Int=2: Number of layers in the decoder (projection)

Notes

  • When num_encoder_layers=2 and num_decoder_layers=2, this behaves identically to the original implementation
  • lifting_channel_ratio controls the hidden dimension ratio for the encoder
  • projection_channel_ratio controls the hidden dimension ratio for the decoder
ESM_PINO.SFNO_Block - Method
SFNO_Block(
    channels::Int64,
    ggsh::QG3.GaussianGridtoSHTransform,
    shgg::QG3.SHtoGaussianGridTransform;
    modes,
    expansion_factor,
    activation,
    skip,
    operator_type,
    use_norm,
    soft_gating,
    bias
) -> ESM_PINO.SFNO_Block{ESM_PINOQG3Ext.ESM_PINOQG3}

A block that combines a spherical kernel with a channel MLP. Expects input in (spatial..., channel, batch) format.

Arguments

  • channels::Int: Number of input/output channels
  • ggsh::GaussianGridtoSHTransform: Transformation from Gaussian grid to spherical harmonics
  • shgg::SHtoGaussianGridTransform: Transformation from spherical harmonics back to Gaussian
  • modes::Int=ggsh.output_size[1]: Number of spherical harmonic modes to retain (default: ggsh.output_size[1])
  • expansion_factor::Real=2.0: Expansion factor for the ChannelMLP (default: 2.0)
  • activation: Activation function applied after combining spatial and spectral branches (default: NNlib.gelu)
  • skip::Bool=true: Whether to include a skip connection (default: true)
  • zsk::Bool=false: Whether to use Zonal Symmetric Kernels (ZSK) (default: false)

Returns

  • SFNO_Block: A Lux-compatible layer operating on 4D arrays [lat, lon, channels, batch].

Fields

  • spherical_kernel::SphericalKernel: Spherical kernel layer
  • channel_mlp::ChannelMLP: Channel-wise MLP layer
  • channels::Int: Number of input/output channels
  • skip::Bool: Whether to include a skip connection
  • activation::Function: Activation function applied after the block

Details

  • The input is processed by a SphericalKernel followed by a ChannelMLP
  • If skip is true, the input is added to the output (residual connection)
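
The block logic described above can be sketched as follows. Here, sfno_block_forward and the toy spherical_kernel/channel_mlp callables are illustrative stand-ins, not the package's actual types:

```julia
# Conceptual sketch of an SFNO block's forward pass (not the package's actual
# implementation): a spherical kernel, then a channel MLP, then an optional
# residual connection.
function sfno_block_forward(x, spherical_kernel, channel_mlp; skip::Bool=true)
    y = channel_mlp(spherical_kernel(x))
    return skip ? x .+ y : y
end
```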
ESM_PINO.SFNO_Block - Method
SFNO_Block(
    channels::Int64,
    pars::QG3ModelParameters;
    modes,
    batch_size,
    expansion_factor,
    activation,
    skip,
    gpu,
    operator_type,
    use_norm,
    soft_gating,
    bias
) -> ESM_PINO.SFNO_Block{ESM_PINOQG3Ext.ESM_PINOQG3}

A block that combines a spherical kernel with a channel MLP. Expects input in (spatial..., channel, batch) format.

Arguments

  • channels::Int: Number of input/output channels
  • pars::QG3ModelParameters: Precomputed QG3 model parameters (QG3ModelParameters)
  • modes::Int=pars.L: Number of spherical harmonic modes to retain (default: pars.L)
  • batch_size::Int=1: Batch size for transforms (default: 1)
  • expansion_factor::Real=2.0: Expansion factor for the ChannelMLP (default: 2.0)
  • activation: Activation function applied after combining spatial and spectral branches (default: NNlib.gelu)
  • skip::Bool=true: Whether to include a skip connection (default: true)
  • gpu::Bool=true: Whether to use GPU (default: true)
  • zsk::Bool=false: Whether to use Zonal Symmetric Kernels (ZSK) (default: false)

Returns

  • SFNO_Block: A Lux-compatible layer operating on 4D arrays [lat, lon, channels, batch].

Fields

  • spherical_kernel::SphericalKernel: Spherical kernel layer
  • channel_mlp::ChannelMLP: Channel-wise MLP layer
  • channels::Int: Number of input/output channels
  • skip::Bool: Whether to include a skip connection
  • activation::Function: Activation function applied after the block

Details

  • The input is processed by a SphericalKernel followed by a ChannelMLP
  • If skip is true, the input is added to the output (residual connection)
ESM_PINO.SphericalKernel - Method
SphericalKernel(
    hidden_channels::Int64,
    ggsh::QG3.GaussianGridtoSHTransform,
    shgg::QG3.SHtoGaussianGridTransform;
    use_norm,
    modes,
    operator_type,
    inner_mixing,
    bias
) -> ESM_PINO.SphericalKernel{ESM_PINOQG3Ext.ESM_PINOQG3}

Construct a SphericalKernel layer using precomputed transforms.

Arguments

  • hidden_channels::Int: Number of channels
  • ggsh::GaussianGridtoSHTransform: Transformation from Gaussian grid to spherical harmonics
  • shgg::SHtoGaussianGridTransform: Transformation from spherical harmonics back to Gaussian grid
  • activation: Activation function applied after combining spatial and spectral branches (default: NNlib.gelu)
  • modes::Int=ggsh.output_size[1]: Number of spherical harmonic modes to retain (default: ggsh.output_size[1])
  • zsk::Bool=false: Whether to use Zonal Symmetric Kernels (ZSK) (default: false)

Returns

  • SphericalKernel: A Lux-compatible layer operating on 4D arrays [lat, lon, channels, batch].

Fields

  • spatial_conv::P: 1x1 convolution operating directly in the spatial domain
  • spherical_conv::SphericalConv: Spherical convolution layer
  • norm::Union{Lux.InstanceNorm, Lux.NoOpLayer}: Optional normalization layer

ESM_PINO.SphericalKernel - Method
SphericalKernel(
    hidden_channels::Int64,
    pars::QG3ModelParameters;
    use_norm,
    modes,
    batch_size,
    gpu,
    operator_type,
    inner_mixing,
    bias
) -> ESM_PINO.SphericalKernel{ESM_PINOQG3Ext.ESM_PINOQG3}

Combines a SphericalConv layer with a 1x1 convolution in parallel, followed by an activation function. Expects input in (spatial..., channel, batch) format.

Arguments

  • hidden_channels: Number of channels
  • pars: Precomputed QG3 model parameters (QG3ModelParameters)
  • activation: Activation function applied after combining spatial and spectral branches (default: NNlib.gelu)
  • modes: Number of spherical harmonic modes to retain (default: pars.L)
  • batch_size: Batch size for transforms (default: 1)
  • gpu: Whether to use GPU (default: true)
  • zsk: Whether to use Zonal Symmetric Kernels (ZSK) (default: false)

Returns

  • SphericalKernel: A Lux-compatible layer operating on 4D arrays [lat, lon, channels, batch].

Fields

  • spatial_conv::P: 1x1 convolution operating directly in the spatial domain
  • spherical_conv::SphericalConv: Spherical convolution layer
  • norm::Union{Lux.InstanceNorm, Lux.NoOpLayer}: Optional normalization layer

Details

  • The input is processed in parallel by a 1x1 convolution and a spherical convolution
  • Outputs from both branches are summed and passed through the activation
  • Useful for mixing local (spatial) and global (spectral) information
ESM_PINOQG3Ext.GaussianGridInfo - Type
GaussianGridInfo

Structure containing information about a Gaussian grid resolution.

Fields

  • truncation::Int: Spectral truncation number (e.g., 31 for T31)
  • nlat::Int: Number of latitude points
  • nlon::Int: Number of longitude points
  • km_at_equator::Float64: Approximate grid spacing at equator in km
  • deg_at_equator::Float64: Approximate grid spacing at equator in degrees
  • description::String: Human-readable description
ESM_PINOQG3Ext.RemapPlan - Type
RemapPlan

Precomputed plan for efficient array remapping. Stores source and destination indices to avoid recomputation.
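
A minimal sketch of the idea behind such a plan. SimpleRemapPlan and its src/dst fields are hypothetical names for illustration; the actual RemapPlan layout is internal to the extension:

```julia
# Hypothetical sketch: a plan stores precomputed source and destination indices,
# so a remap is a single gather/scatter with no index recomputation.
struct SimpleRemapPlan
    src::Vector{Int}   # linear indices to read from the source array
    dst::Vector{Int}   # linear indices to write into the destination array
end

function remap(A::AbstractArray, plan::SimpleRemapPlan, out_len::Int)
    out = zeros(eltype(A), out_len)
    out[plan.dst] = A[plan.src]   # apply the precomputed mapping in one step
    return out
end
```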

ESM_PINOQG3Ext.calculate_gaussian_grid_size - Method
calculate_gaussian_grid_size(truncation::Int) -> Tuple{Int, Int}

Calculate Gaussian grid dimensions from spectral truncation number using standard formulas.

For a spectral truncation T, the standard relationships are:

  • nlat = (truncation + 1) * 3 / 2 (for reduced grids, varies slightly)
  • nlon = 2 * nlat (for regular grids)

Arguments

  • truncation::Int: Spectral truncation number

Returns

  • Tuple{Int, Int}: (nlat, nlon)
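
The formulas above can be sketched as follows. grid_size_from_truncation is an illustrative helper, not the package function; grids that the extension resolves via lookup tables (e.g. T255) may deviate from these formulas:

```julia
# Sketch of the quoted standard relationships:
# nlat = 3 * (T + 1) / 2, nlon = 2 * nlat (regular grid).
function grid_size_from_truncation(truncation::Int)
    nlat = div(3 * (truncation + 1), 2)   # e.g. T31 -> 48 latitudes
    nlon = 2 * nlat                       # twice as many longitudes
    return (nlat, nlon)
end
```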
ESM_PINOQG3Ext.compute_ACC - Method
compute_ACC(X_pred, X_true, lats, ltm; l=1)

Compute the Anomaly Correlation Coefficient (ACC) for a variable at time-step l.

Arguments

  • X_pred: Predicted values, array with dimensions [lat, lon, time] or [lat, lon]
  • X_true: True values, array with dimensions [lat, lon, time] or [lat, lon]
  • lats: Latitude values in degrees for each latitude grid point (length NLat)
  • ltm: Long-term mean to subtract (array with dimensions [lat, lon])

Returns

  • ACC value (scalar between -1 and 1)
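
A minimal sketch of a latitude-weighted ACC for a single 2D field, assuming simple cos(latitude) weights (the package may use Gaussian quadrature weights from the model parameters instead):

```julia
# Illustrative latitude-weighted anomaly correlation for a [lat, lon] field.
function acc_2d(X_pred, X_true, lats, ltm)
    w = cosd.(lats)                       # one weight per latitude (assumed scheme)
    W = repeat(w, 1, size(X_pred, 2))     # expand weights to [lat, lon]
    a = X_pred .- ltm                     # predicted anomalies
    b = X_true .- ltm                     # true anomalies
    return sum(W .* a .* b) / sqrt(sum(W .* a .^ 2) * sum(W .* b .^ 2))
end
```

A perfect forecast gives an ACC of 1, a sign-flipped one gives -1.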
ESM_PINOQG3Ext.compute_ACC - Method
compute_ACC(X_pred, X_true, pars::QG3.QG3ModelParameters, ltm)

Compute the Anomaly Correlation Coefficient (ACC) for a variable at each time-step.

Arguments

  • X_pred: Predicted values, array with dimensions [lat, lon, time] or [lat, lon]
  • X_true: True values, array with dimensions [lat, lon, time] or [lat, lon]
  • pars::QG3.QG3ModelParameters: QG3 model parameters containing latitude weights
  • ltm: Long-term mean to subtract (array with dimensions [lat, lon])

Returns

  • ACC value(s): scalar if input is 2D, vector of length time if input is 3D
ESM_PINOQG3Ext.fine_tuning - Method
fine_tuning(
    x::AbstractArray,
    target::AbstractArray,
    x_val::AbstractArray,
    target_val::AbstractArray,
    model::SFNO,
    ps::NamedTuple,
    st::NamedTuple;
    seed,
    n_steps,
    nepochs,
    num_examples,
    num_valid,
    lr_0,
    gpu,
    parameters,
    use_physics,
    geometric,
    spectral,
    α,
    β,
    logging
) -> LuxCore.StatefulLuxLayerImpl.StatefulLuxLayer{Val{true}, SFNO{E, L, B, P, Q}, _A, NamedTuple{names, T}} where {E, L, B, P, Q, _A, names, T<:Tuple}

Fine-tune a pretrained SFNO model using an autoregressive (AR) loss function. This procedure is typically applied after initial training to improve multi-step forecast accuracy.

The function performs a short fine-tuning loop with autoregressive supervision, optionally including a physics-informed loss component.

Arguments

  • x::AbstractArray: Input data tensor.
  • target::AbstractArray: Target data tensor with shape (lat, lon, channels, batch, time).
  • x_val::AbstractArray: Validation input tensor.
  • target_val::AbstractArray: Validation target tensor.
  • model::SFNO: Pretrained SFNO model to be fine-tuned.
  • ps::NamedTuple: Model parameters (from previous training).
  • st::NamedTuple: Model internal state.

Keywords

  • n_steps::Int=2: Number of autoregressive steps in the loss function.
  • maxiters::Int=5: Maximum number of fine-tuning iterations.
  • lr_0::Float64=1e-5: Learning rate for fine-tuning.
  • parameters::QG3_Physics_Parameters=QG3_Physics_Parameters(): Physical parameters used in the loss.
  • use_physics::Bool=true: Include physics-informed component in the loss if true.
  • geometric::Bool=true: Use geometric formulation of the physics loss.
  • α::Float32=0.7f0: Weighting factor between physics and data loss terms.

Returns

  • StatefulLuxLayer{true}: Fine-tuned model instance with updated parameters and states.

Notes

  • The target tensor must have five dimensions, with the number of autoregressive steps as the fifth dimension.
  • The time dimension (size(target, 5)) must match the number of autoregressive steps (n_steps).

Example

# Fine-tune a pretrained SFNO model for multi-step forecasting
ft_model = fine_tuning(x_train, y_train, x_val, y_val, pretrained_model, ps, st;
                       n_steps=3, maxiters=10, lr_0=1e-5)

# Evaluate fine-tuned model
pred = ft_model(x_val)
ESM_PINOQG3Ext.gaussian_resolution_to_grid - Method
gaussian_resolution_to_grid(resolution::AbstractString) -> Tuple{Int, Int}

Convert a Gaussian grid resolution string (e.g., "T31", "T63") to (nlat, nlon) tuple.

Arguments

  • resolution::AbstractString: Grid resolution in format "TN" where N is truncation number

Returns

  • Tuple{Int, Int}: (number of latitude points, number of longitude points)

Examples

julia> gaussian_resolution_to_grid("T31")
(48, 96)

julia> gaussian_resolution_to_grid("T63")  
(96, 192)

julia> gaussian_resolution_to_grid("T255")
(256, 512)

Throws

  • ArgumentError: If resolution is not recognized
ESM_PINOQG3Ext.get_truncation_from_nlat - Method
get_truncation_from_nlat(nlat::Int) -> Int

Retrieve the spectral truncation number for a given number of latitude points.

Arguments

  • nlat::Int: Number of latitude points

Returns

  • Int: Spectral truncation number (e.g., 31 for T31)

Examples

julia> get_truncation_from_nlat(48)
31

julia> get_truncation_from_nlat(96)
63

julia> get_truncation_from_nlat(256)
255

Throws

  • ArgumentError: If nlat doesn't match any known Gaussian grid resolution
ESM_PINOQG3Ext.make_QG3_loss - Method
make_QG3_loss(pars::QG3_Physics_Parameters;
              α=0.5f0,
              use_physics::Bool=true,
              geometric::Bool=false)

Create a composite QG3 loss function suitable for Lux training. Returns a callable (model, ps, st, (input, target)) -> (loss, st, metrics).
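
A plausible sketch of the α-weighted combination. composite_loss is a hypothetical helper, and the convention that α multiplies the physics term (rather than the data term) is an assumption:

```julia
# Assumed form of the composite loss: total = α * physics + (1 - α) * data,
# with the physics term switched off when use_physics=false.
function composite_loss(data_loss, physics_loss; α=0.5f0, use_physics::Bool=true)
    return use_physics ? α * physics_loss + (1 - α) * data_loss : data_loss
end
```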

ESM_PINOQG3Ext.make_autoregressive_loss - Method
make_autoregressive_loss(QG3_loss::Function; steps::Int, sequential::Bool=true)

Create an autoregressive loss function that rolls out predictions over steps and accumulates the loss defined by QG3_loss.

Arguments

  • QG3_loss: A loss function of the form (model, ps, st, (input, target)) -> (loss, st, details)
  • steps: Number of autoregressive rollout steps
  • sequential: If true, predictions are fed sequentially (standard autoregressive); if false, all predictions are computed and compared in batch for efficiency.

Returns

  • A loss function (model, ps, st, (u_t1, targets)) -> (loss, st, details)
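
The sequential rollout can be sketched as follows. make_ar_loss_sketch is illustrative, and the model call convention model(x, ps, st) -> (prediction, st) is an assumption:

```julia
# Sketch of a sequential autoregressive rollout: at each step the base loss is
# evaluated against the k-th target, and the model's prediction is fed back in.
function make_ar_loss_sketch(base_loss; steps::Int)
    return function (model, ps, st, (u, targets))
        total = 0.0
        for k in 1:steps
            yk = selectdim(targets, ndims(targets), k)  # k-th rollout target
            l, st, _ = base_loss(model, ps, st, (u, yk))
            total += l
            u, st = model(u, ps, st)                    # feed prediction back
        end
        return total, st, (;)
    end
end
```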
ESM_PINOQG3Ext.preprocess_data - Method
preprocess_data(solu::AbstractArray;
    noise_level::Real=0.0,
    normalize::Bool=true,
    channelwise::Bool=false,
    to_gpu::Bool=false,
    noise_type::Symbol=:gaussian,
    slice_range::Union{Nothing, UnitRange, Tuple}=nothing,
    train_fraction::Real=0.8
)

Preprocess simulation data with normalization, noise injection, and train/validation splitting.

Arguments

  • solu::AbstractArray: Raw simulation data array of shape (lat, lon, channel, time)

Keyword Arguments

  • noise_level::Real=0.0: Standard deviation of noise to add (0.0 = no noise)
  • normalize::Bool=true: Whether to normalize the data
  • channelwise::Bool=false: If true, normalize each channel independently
  • to_gpu::Bool=false: If true, transfer data to GPU after preprocessing
  • noise_type::Symbol=:gaussian: Type of noise (:gaussian, :uniform, :salt_pepper)
  • slice_range::Union{Nothing, UnitRange, Tuple}=nothing: Optional range(s) to slice data
  • train_fraction::Real=0.8: Fraction of data to use for training (rest used for validation)

Returns

  • q_0_train: Training initial conditions
  • q_evolved_train: Training evolved states
  • q_0_val: Validation initial conditions
  • q_evolved_val: Validation evolved states
  • μ: Mean(s) used for normalization
  • σ: Std(s) used for normalization
  • normalization_params: NamedTuple with normalization metadata
ESM_PINOQG3Ext.preprocess_data - Method
preprocess_data(;
    noise_level::Real=0.0,
    normalize::Bool=true,
    channelwise::Bool=false,
    to_gpu::Bool=false,
    noise_type::Symbol=:gaussian,
    slice_range::Union{Nothing, UnitRange, Tuple}=nothing,
    train_fraction::Real=0.8
)

Preprocess simulation data with normalization, noise injection, and train/validation splitting.

Keyword Arguments

  • noise_level::Real=0.0: Standard deviation of noise to add (0.0 = no noise)
  • normalize::Bool=true: Whether to normalize the data
  • channelwise::Bool=false: If true, normalize each channel independently
  • to_gpu::Bool=false: If true, transfer data to GPU after preprocessing
  • noise_type::Symbol=:gaussian: Type of noise (:gaussian, :uniform, :salt_pepper)
  • slice_range::Union{Nothing, UnitRange, Tuple}=nothing: Optional range(s) to slice data
  • train_fraction::Real=0.8: Fraction of data to use for training (rest used for validation)

Returns

  • q_0_train: Training initial conditions
  • q_evolved_train: Training evolved states
  • q_0_val: Validation initial conditions
  • q_evolved_val: Validation evolved states
  • μ: Mean(s) used for normalization
  • σ: Std(s) used for normalization
  • normalization_params: NamedTuple with normalization metadata

Examples

# Basic usage with 80% train, 20% validation
q0_tr, qe_tr, q0_val, qe_val, μ, σ, params = preprocess_data(
    normalize=true,
    train_fraction=0.8
)

# Custom train/val split with noise
q0_tr, qe_tr, q0_val, qe_val, μ, σ, params = preprocess_data(
    noise_level=0.01,
    normalize=true,
    channelwise=true,
    train_fraction=0.7,
    to_gpu=true
)
ESM_PINOQG3Ext.qg3pars_constructor_helper - Method
qg3pars_constructor_helper(
    L::Int64,
    n_lat::Int64;
    n_lon,
    iters,
    tol,
    NF
) -> QG3ModelParameters{Float32, Int64, Vector{Float32}, Matrix{Float32}}

Helper function that wraps the constructor for QG3ModelParameters using a Gaussian grid. Generates latitude/longitude points and initializes empty topography and land/sea mask. Used mainly to handle SH transforms.

Arguments

  • L::Int: Spectral truncation level (maximum degree).
  • n_lat::Int: Number of Gaussian latitudes.

Keywords

  • n_lon::Int=2*n_lat: Number of longitudes (default: twice the latitude count).
  • iters::Int=100: Maximum number of iterations for Gaussian grid convergence.
  • tol::Real=1e-8: Convergence tolerance.
  • NF::Type{<:AbstractFloat}=Float32: Number format for outputs.

Returns

  • QG3ModelParameters: Model parameters including grid coordinates, topography (h), and land-sea mask (LS).

Example

pars = qg3pars_constructor_helper(42, 64)
ESM_PINOQG3Ext.reorderQG3_indexes - Method
reorderQG3_indexes(A::AbstractMatrix)

Reorder a 2D array by applying circular shifts to each column. Column j gets shifted by j÷2 positions downward.

GPU-compatible and Zygote-differentiable.
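
A naive reference implementation of the described shift, assuming 1-based column indices (the package version instead uses a GPU-compatible, Zygote-differentiable gather rather than per-column circshift):

```julia
# Illustrative only: shift column j downward (circularly) by j ÷ 2 positions.
function reorder_columns(A::AbstractMatrix)
    return hcat((circshift(A[:, j], j ÷ 2) for j in axes(A, 2))...)
end
```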

ESM_PINOQG3Ext.reorderQG3_indexes_4d - Method
reorderQG3_indexes_4d(A::AbstractArray{T,4})

Reorder a 4D array by applying the 2D reordering to dimensions 2 and 3. Dimensions 1 and 4 are preserved.

GPU-compatible and Zygote-differentiable.

ESM_PINOQG3Ext.stack_time_steps - Method
stack_time_steps(
    data::AbstractArray{T, 4},
    time_steps::Int64;
    dt,
    N_sims,
    gpu
) -> Any

Convert a 4D tensor of sequential data (lat, lon, channels, batch) into a 5D tensor suitable for autoregressive training or evaluation. The function constructs overlapping sequences along the batch dimension, each containing time_steps consecutive snapshots.

Arguments

  • data::AbstractArray{T,4}: Input data tensor with dimensions (lat, lon, channels, batch).
  • time_steps::Int: Number of consecutive time steps to include in each sequence.

Returns

  • AbstractArray{T,5}: A 5D tensor of shape (lat, lon, channels, time_steps, n_sequences), where n_sequences = batch - time_steps + 1.

Notes

  • The resulting array can be used as autoregressive training targets for multi-step prediction.
  • Sequences are created by sliding a window of length time_steps along the batch axis.

Example

# Input: 4D array with 10 time samples
data = rand(Float32, 64, 128, 3, 10)

# Stack into 5D sequences of 4 time steps each
seq_data = stack_time_steps(data, 4)

@assert size(seq_data) == (64, 128, 3, 4, 7)
ESM_PINOQG3Ext.train_model - Method
train_model(
    x::AbstractArray,
    target::AbstractArray,
    pars::QG3ModelParameters;
    seed,
    maxiters,
    batchsize,
    modes,
    in_channels,
    out_channels,
    hidden_channels,
    n_layers,
    lifting_channel_ratio,
    projection_channel_ratio,
    channel_mlp_expansion,
    activation,
    positional_embedding,
    inner_skip,
    outer_skip,
    operator_type,
    use_norm,
    downsampling_factor,
    lr_0,
    gpu,
    parameters,
    use_physics,
    geometric,
    α
) -> LuxCore.StatefulLuxLayerImpl.StatefulLuxLayer{Val{true}, SFNO{E, L, B, P, Q}, _A, NamedTuple{names, T}} where {E, L, B, P, Q, _A, names, T<:Tuple}

Train an SFNO model, optionally using a combined data-driven (plain MSE or geometrically weighted) and physics-informed loss. This function initializes the model, optimizer, and training loop, and iteratively optimizes the model parameters.

Both standard data loss and an optional physics-informed term are tracked during training.

Arguments

  • x::AbstractArray: Input training data tensor.
  • target::AbstractArray: Target (ground truth) data tensor.
  • pars::QG3ModelParameters: Model configuration including grid and spectral parameters.

Keywords

  • seed::Int=0: Random seed for reproducibility.
  • maxiters::Int=20: Number of training iterations.
  • batchsize::Int=256: Mini-batch size for training.
  • modes::Int=pars.L: Spectral truncation level.
  • in_channels::Int=3: Number of input channels.
  • out_channels::Int=3: Number of output channels.
  • hidden_channels::Int=256: Width of the hidden feature layers.
  • n_layers::Int=4: Number of SFNO layers.
  • lifting_channel_ratio::Int=2: Ratio of lifting layer expansion.
  • projection_channel_ratio::Int=2: Ratio of projection layer contraction.
  • channel_mlp_expansion::Number=2.0: Expansion factor in channel MLP blocks.
  • activation: Activation function used in SFNO blocks (default: NNlib.gelu).
  • positional_embedding::AbstractString="grid": Type of positional embedding ("grid" or "no_grid").
  • inner_skip::Bool=true: Whether to enable residual connections inside SFNO blocks.
  • outer_skip::Bool=true: Whether to enable skip connections between lifting output and projection input.
  • zsk::Bool=false: Use zonally symmetric kernel formulation if true.
  • use_norm::Bool=false: Apply normalization layers inside SFNO blocks.
  • downsampling_factor::Int=2: Ratio of downsampling between layers.
  • lr_0::Float64=1e-3: Initial learning rate for the optimizer.
  • parameters::QG3_Physics_Parameters=QG3_Physics_Parameters(pars, batch_size=batchsize): Physical parameters used in the QG3 loss.
  • use_physics::Bool=true: Whether to include the physics-informed loss component.
  • geometric::Bool=true: Use geometrically weighted formulation for the data loss.
  • α::Float32=0.7f0: Weighting factor between physics loss and data loss.

Returns

  • StatefulLuxLayer{true}: Trained SFNO model containing learned parameters and internal state.

Example

# Initialize parameters and data
pars = qg3pars_constructor_helper(42, 64)
x, y = generate_training_data(pars)

# Train SFNO model
trained_model = train_model(x, y, pars; maxiters=100, batchsize=128, lr_0=5e-4)

# Perform inference
pred = trained_model(x)
ESM_PINOQG3Ext.transfer_SFNO_model - Method
transfer_SFNO_model(
    model::SFNO,
    qg3ppars::QG3ModelParameters;
    batch_size,
    gpu
) -> SFNO{E, L, B, P, ESM_PINOQG3Ext.ESM_PINOQG3} where {E, L, B, P}

Construct a new SFNO model that replicates the architecture and parameters of an existing model, but adapts them to a new discretization (qg3ppars) and batch size. This function preserves all spectral modes, channels, and hyperparameters while adjusting the internal transform plans to match the new grid configuration.

Arguments

  • model::SFNO: Source SFNO model whose architecture and parameters will be cloned.
  • qg3ppars: Target problem parameters (e.g., grid, spectral resolution).

Keywords

  • batch_size::Int: Optional new batch size. Defaults to the batch size inferred from model.sfno_blocks.layers.layer_1.spherical_kernel.spherical_conv.plan.ggsh.FT_4d.plan.input_size[4].

Returns

  • SFNO: A new model instance with the same architecture as model, configured for the target discretization and batch size.

Example

# Original model (batch size = 32)
model = SFNO(orig_pars; batch_size=32, ...)

# Transfer model to a finer grid and larger batch
new_model = transfer_SFNO_model(model, new_pars; batch_size=64)

# Forward pass with transferred weights
ŷ = new_model(x, ps, st)