Source code for optframework.utils.func.jit_pbm_rhs

import numpy as np
import math
from numba import jit
import optframework.utils.func.jit_pbm_qmom as qmom
import optframework.utils.func.jit_kernel_agg as kernel_agg
import optframework.utils.func.jit_kernel_break as kernel_break
import optframework.utils.func.jit_pbm_chyqmom as chyqmom
from scipy.optimize import root
from scipy.optimize import least_squares
from numba import njit

[docs]def filter_negative_nodes(xi, wi, threshold=1e-7): """ Filter out negative nodes with small weights or raise an error if negative nodes have significant weights. Parameters: xi (array-like): Array of nodes (abscissas). wi (array-like): Array of weights corresponding to the nodes. threshold (float): Weight threshold below which negative nodes can be removed. Returns: tuple: Filtered (xi, wi) arrays with negative nodes removed if their weights are below threshold. Raises: ValueError: If a negative node has a weight exceeding the threshold. """ xi = np.array(xi) wi = np.array(wi) if len(xi) != len(wi): raise ValueError("xi and wi must have the same length.") # Indices where xi < 0 negative_indices = np.where(xi < 0)[0] # Check each negative xi for idx in negative_indices: if wi[idx] >= threshold: raise ValueError( f"Node xi[{idx}] = {xi[idx]} is negative, but its weight wi[{idx}] = {wi[idx]} exceeds the threshold {threshold}." ) # Remove elements where xi < 0 and wi < threshold valid_indices = [i for i in range(len(xi)) if xi[i] >= 0 or wi[i] >= threshold] # Filter xi and wi xi_filtered = xi[valid_indices] wi_filtered = wi[valid_indices] return xi_filtered, wi_filtered
[docs]def hyqmom_newton_correction(xi, wi, moments, method="lm"): """ Correct QMOM abscissas and weights through Newton iteration to satisfy moment equations, particularly addressing negative abscissas. Parameters: xi (array-like): Quadrature abscissas (may contain negative values). wi (array-like): Quadrature weights. moments (array-like): Target moments [M0, M1, M2, ...]. method (str): Optimization method for least squares solver. Returns: tuple: Corrected (xi, wi) arrays that better satisfy the moment equations. """ SMALL = 1e-8 # Handle negative abscissas negative_idx = xi < 0 if np.any(negative_idx): # Replace negative values with a small positive value xi[negative_idx] = SMALL else: # No correction needed if no negative values return xi, wi xi_free_idx = ~negative_idx # Indices of abscissas to optimize xi_free = xi[xi_free_idx] wi_free = wi # Define residual function for optimization def residuals(variables): """ Calculate residuals between current moments and target moments """ xi_new = xi.copy() xi_new[xi_free_idx] = variables[:len(xi_free)] wi_new = variables[len(xi_free):] # Focus on first three moments for stability return np.array([np.sum(wi_new * (xi_new ** k)) - moments[k] for k in range(3)]) # Initial guess combining free abscissas and weights x0 = np.concatenate([xi_free, wi_free]) # Solve using least squares optimization sol = least_squares(residuals, x0, method="trf", ftol=1e-3) if not sol.success: raise ValueError(f"Optimization failed to converge: {sol.message}") # Extract optimized values xi[xi_free_idx] = sol.x[:len(xi_free)] wi = sol.x[len(xi_free):] return xi, wi
@njit def get_dMdt_1d(t, moments, x_max, GQMOM, GQMOM_method, moments_norm_factor, n_add, nu, COLEVAL, CORR_BETA, G, alpha_prim, SIZEEVAL, V_unit, X_SEL, Y_SEL, pl_P1, pl_P2, BREAKRVAL, v, q, BREAKFVAL, type_flag): """ Calculate the moment derivatives for 1D population balance equations, handling agglomeration and/or breakage processes. Parameters: t (float): Current time. moments (array): Current moments. x_max (float): Maximum coordinate value for scaling. GQMOM (bool): Flag to use Generalized QMOM instead of standard QMOM. GQMOM_method (str): Method for GQMOM calculations ("gaussian"). moments_norm_factor (array): Normalization factors for moments. n_add (int): Number of additional nodes for GQMOM. nu (float): Exponent for the correction in gaussian-GQMOM. COLEVAL (int): Case for collision kernel calculation. CORR_BETA (float): Correction term for collision frequency. G (float): Shear rate [1/s]. alpha_prim (float): Primary particle interaction parameter. SIZEEVAL (int): Case for size dependency. V_unit (float): Unit volume used for concentration calculations. X_SEL (float): Size dependency parameter. Y_SEL (float): Size dependency parameter. pl_P1 (float): First parameter in power law for breakage rate. pl_P2 (float): Second parameter in power law for breakage rate. BREAKRVAL (int): Breakage rate model selector. v (float): Number of fragments in product function of power law. q (float): Parameter describing the breakage type in product function. BREAKFVAL (int): Breakage fragment distribution model selector. type_flag (str): Process type: "agglomeration", "breakage", or "mix". Returns: array: (Normalized) moment derivatives (dM/dt). Raises: ValueError: If moments are not realizable. """ dMdt = np.zeros(moments.shape) dMdt_norm = np.zeros(moments.shape) # Check if moments are realizable if moments[0] <= 0: raise ValueError("Wheeler: Moments are NOT realizable (moment[0] <= 0.0).") m = len(moments) n = m // 2 # Number of xi based on available moments adaptive = False use_central=False # Calculate quadrature nodes and weights if not GQMOM: xi, wi, n = qmom.calc_qmom_nodes_weights(moments, n, adaptive, use_central) else: xi, wi, n = qmom.calc_gqmom_nodes_weights(moments, n, n_add, GQMOM_method, nu, adaptive, use_central) # xi, wi = filter_negative_nodes(xi, wi) n += n_add # Check for negative nodes (debugging) # if np.any(xi<0): # print(f"t = {t}") # print(f"xi = {xi}, wi = {wi}") # Scale nodes back to original domain xi *= x_max # Prepare volume and radius arrays (with zero padding at index 0) V = np.zeros(n+1) R = np.zeros(n+1) V[1:] = xi R[1:] = (V[1:]*3/(4*math.pi))**(1/3) # Convert volume to radius # Calculate agglomeration terms if needed if type_flag == "agglomeration" or type_flag == "mix": # Get agglomeration frequency matrix F_M_tem = np.zeros((n+1,n+1)) F_M_tem = kernel_agg.calc_F_M_1D_jit(F_M_tem, COLEVAL, CORR_BETA, G, R, alpha_prim, SIZEEVAL, X_SEL, Y_SEL) F_M = F_M_tem[1:,1:] / V_unit # Calculate breakage terms if needed if type_flag == "breakage" or type_flag == "mix": B_F_intxk = np.zeros((m, n)) # Calculate breakage rates for each node B_R = np.zeros_like(V[1:]) B_R = kernel_break.calc_B_R_1d_jit(V[1:], B_R, pl_P1, pl_P2, G, BREAKRVAL) # Gauss-Legendre quadrature points and weights for integration xs1 = np.array([-9.681602395076260859e-01, -8.360311073266357695e-01, -6.133714327005903577e-01, -3.242534234038089158e-01, 0.000000000000000000e+00, 3.242534234038089158e-01, 6.133714327005903577e-01, 8.360311073266357695e-01, 9.681602395076260859e-01,]) ws1 = np.array([8.127438836157471758e-02, 1.806481606948571184e-01, 2.606106964029356599e-01, 3.123470770400028074e-01, 3.302393550012596712e-01, 3.123470770400028074e-01, 2.606106964029356599e-01, 1.806481606948571184e-01, 8.127438836157471758e-02]) # Calculate breakage daughter distribution integral for each moment and node for k in range(m): for i in range(n): argsk = (V[i+1],k,v,q,BREAKFVAL) func = kernel_break.breakage_func_1d_xk B_F_intxk[k, i] = kernel_break.gauss_legendre(func, 0.0, V[i+1], xs1, ws1, args=argsk) # Calculate moment derivatives for each moment order for k in range(m): dMdt_agg_ij = 0.0 dMdt_break_i = 0.0 for i in range(n): # Agglomeration contribution: birth - death if type_flag == "agglomeration" or type_flag == "mix": for j in range(n): dMdt_agg_ij += 0.5 * wi[i]*wi[j]*F_M[i, j]*((xi[i]+xi[j])**k-xi[i]**k-xi[j]**k) # Breakage contribution: birth - death if type_flag == "breakage" or type_flag == "mix": dMdt_break_i += wi[i] * B_R[i] * B_F_intxk[k, i] - wi[i] * xi[i]**k * B_R[i] # Total derivative for moment k dMdt[k] = dMdt_agg_ij + dMdt_break_i # Normalize derivatives dMdt_norm = dMdt / moments_norm_factor return dMdt_norm @njit def get_dMdt_2d(t, moments, n, indices, COLEVAL, CORR_BETA, G, alpha_prim, SIZEEVAL, V_unit, X_SEL, Y_SEL, pl_P1, pl_P2, pl_P3, pl_P4, BREAKRVAL, v, q, BREAKFVAL, type_flag): """ Calculate the moment derivatives for 2D population balance equations, handling agglomeration and/or breakage processes with bivariate distributions. Parameters: t (float): Current time. moments (array): Current moments of the bivariate distribution. n (int): Number of quadrature nodes per dimension. indices (array): 2D array of moment indices (i,j) corresponding to moments. COLEVAL (int): Case for collision kernel calculation. CORR_BETA (float): Correction term for collision frequency. G (float): Shear rate [1/s]. alpha_prim (array): Primary particle interaction parameters array. SIZEEVAL (int): Case for size dependency. V_unit (float): Unit volume used for concentration calculations. X_SEL (float): Size dependency parameter. Y_SEL (float): Size dependency parameter. pl_P1 (float): First parameter in power law for breakage rate. pl_P2 (float): Second parameter in power law for breakage rate. pl_P3 (float): Third parameter in power law for breakage rate. pl_P4 (float): Fourth parameter in power law for breakage rate. BREAKRVAL (int): Breakage rate model selector. v (float): Number of fragments in product function of power law. q (float): Parameter describing the breakage type in product function. BREAKFVAL (int): Breakage fragment distribution model selector. type_flag (str): Process type: "agglomeration", "breakage", or "mix". Returns: array: Moment derivatives (dM/dt). Raises: ValueError: If moments are not realizable. """ n0 = n dMdt = np.zeros(moments.shape) # Check if moments are realizable if moments[0] <= 0: raise ValueError("Wheeler: Moments are NOT realizable (moment[0] <= 0.0).") # Calculate conditional quadrature for 2D distribution xi, wi, n = qmom.calc_cqmom_2d(moments, n, indices, use_central=True) # for Debug # moments_cqmom = np.zeros_like(moments) # for idx, _ in enumerate(moments_cqmom): # indice = indices[idx] # moments_cqmom[idx] = chyqmom.quadrature_2d(wi, xi, indice) # print(np.mean(abs(moments_cqmom-moments)/moments)) # Print warning if number of nodes was reduced # if n0 > n: # print(f"Warning: At t = {t}, the moments are NOT realizable, abscissas reduced to {n}.") # Initialize volume arrays V = np.ones((n, n)) V1 = xi[0,:] # First component volumes V3 = xi[1,:] # Second component volumes V_flat = np.ones(n*n) # Initialize radius and composition arrays with padding R = np.ones((n+1, n+1)) X1 = np.ones((n+1, n+1)) # Composition fraction of first component X3 = np.ones((n+1, n+1)) # Composition fraction of second component # Calculate volume, radius, and composition for each node for i in range(n): for j in range(n): V[i,j] = V1[i*n+j] + V3[i*n+j] # Total volume V_flat[i*n+j] = V[i,j] X1[i+1,j+1] = V1[i*n+j] / V[i,j] # Composition fraction X3[i+1,j+1] = V3[i*n+j] / V[i,j] # Debug output for negative volumes # if np.any(V<0): # print(f"t = {t}") # print(f"xi = {xi}, wi = {wi}") # Convert volume to radius R[1:, 1:] = (V*3/(4*math.pi))**(1/3) # Calculate agglomeration terms if needed if type_flag == "agglomeration" or type_flag == "mix": # Get agglomeration frequency matrix for 2D F_M_tem = np.zeros((n+1,n+1,n+1,n+1)) F_M_tem = kernel_agg.calc_F_M_2D_jit(F_M_tem, COLEVAL, CORR_BETA, G, R, X1, X3, alpha_prim, SIZEEVAL, X_SEL, Y_SEL) F_M = F_M_tem[1:,1:,1:,1:] / V_unit # Calculate breakage terms if needed if type_flag == "breakage" or type_flag == "mix": # Initialize integral container for daughter distribution B_F_intxk = np.zeros((2*n, n, 2*n, n)) # Calculate breakage rates for each node B_R_flat = np.zeros_like(V_flat) B_R_flat = kernel_break.calc_B_R_2d_flat_jit(V_flat, B_R_flat, V1, V3, G, pl_P1, pl_P2, pl_P3, pl_P4, BREAKRVAL, BREAKFVAL) B_R = np.zeros((n,n)) for i in range(n): for j in range(n): # if V1[i] < eta or V3[j] < eta: # continue B_R[i,j] = B_R_flat[i*n+j] # Gauss-Legendre quadrature points and weights for integration xs1 = np.array([-9.681602395076260859e-01, -8.360311073266357695e-01, -6.133714327005903577e-01, -3.242534234038089158e-01, 0.000000000000000000e+00, 3.242534234038089158e-01, 6.133714327005903577e-01, 8.360311073266357695e-01, 9.681602395076260859e-01,]) ws1 = np.array([8.127438836157471758e-02, 1.806481606948571184e-01, 2.606106964029356599e-01, 3.123470770400028074e-01, 3.302393550012596712e-01, 3.123470770400028074e-01, 2.606106964029356599e-01, 1.806481606948571184e-01, 8.127438836157471758e-02]) xs3 = xs1 ws3 = ws1 # Calculate breakage daughter distribution integral for each moment and node for idx, _ in enumerate(dMdt): k = indices[idx,0] # First component moment order l = indices[idx,1] # Second component moment order for i in range(n): for j in range(n): # if V1[i*n+j] < eta or V3[i*n+j] < eta: # continue # Calculate bivariate daughter distribution integral argsk = (V1[i*n+j],V3[i*n+j],k,l,v,q,BREAKFVAL) func = kernel_break.breakage_func_2d_x1kx3l B_F_intxk[k, i, l, j] = kernel_break.dblgauss_legendre( func, 0.0, V1[i*n+j], 0.0, V3[i*n+j], xs1, ws1, xs3, ws3, args=argsk ) # argsk = (V1[i*n+j],V3[i*n+j],1,v,q,BREAKFVAL,eta) # func1 = kernel_break.breakage_func_2d_x1k_trunc # func2 = kernel_break.breakage_func_2d_x3k_trunc # norm_fac1 = kernel_break.dblgauss_legendre(func1, eta, V1[i*n+j], eta, V3[i*n+j], xs1, ws1, xs3,ws3,args=argsk) # norm_fac2 = kernel_break.dblgauss_legendre(func2, eta, V1[i*n+j], eta, V3[i*n+j], xs1, ws1, xs3,ws3,args=argsk) # argsk_trunc = (V1[i*n+j],V3[i*n+j],k,l,v,q,BREAKFVAL,eta) # func_norm = kernel_break.breakage_func_2d_trunc # B_F_intxk_trunk = kernel_break.dblgauss_legendre(func_norm, eta, V1[i*n+j], eta, V3[i*n+j], xs1, ws1, xs3,ws3,args=argsk_trunc) # B_F_intxk[k, i, l, j] = B_F_intxk_trunk / ((norm_fac1 + norm_fac2) / V[i,j]) # Calculate moment derivatives for each tracked moment for idx, _ in enumerate(dMdt): k = indices[idx,0] # First component moment order l = indices[idx,1] # Second component moment order dMdt_agg_ijab = 0.0 dMdt_break_ij = 0.0 for i in range(n): for j in range(n): # Agglomeration contribution: birth - death if type_flag == "agglomeration" or type_flag == "mix": for a in range(n): for b in range(n): dMdt_agg_ijab += 0.5 * wi[i*n+j]*wi[a*n+b]*F_M[i,j,a,b]*( (xi[0,i*n+j]+xi[0,a*n+b])**k*(xi[1,i*n+j]+xi[1,a*n+b])**l -xi[0,i*n+j]**k*xi[1,i*n+j]**l -xi[0,a*n+b]**k*xi[1,a*n+b]**l) # Breakage contribution: birth - death if type_flag == "breakage" or type_flag == "mix": dMdt_break_ij += (wi[i*n+j] * B_R[i,j] * B_F_intxk[k,i,l,j] - wi[i*n+j] * xi[0,i*n+j]**k * xi[1,i*n+j]**l * B_R[i,j]) # Total derivative for moment (k,l) dMdt[idx] = dMdt_agg_ijab + dMdt_break_ij return dMdt