多项式回归

1:修改了返回值类型
2:修改了多项式计算公式
master
yangzhe123 2025-11-17 13:59:37 +08:00
parent bc008b0174
commit 74000d3ecb
4 changed files with 198 additions and 122 deletions

View File

@ -6,6 +6,7 @@ import com.gunshi.project.hsz.entity.dto.ProjectSafeCalculateDto;
import com.gunshi.project.hsz.entity.dto.ProjectSaveReportDto; import com.gunshi.project.hsz.entity.dto.ProjectSaveReportDto;
import com.gunshi.project.hsz.entity.dto.SyRegressionDataDto; import com.gunshi.project.hsz.entity.dto.SyRegressionDataDto;
import com.gunshi.project.hsz.entity.vo.ForecastResultVo; import com.gunshi.project.hsz.entity.vo.ForecastResultVo;
import com.gunshi.project.hsz.entity.vo.ProjectSafeCalculateVo;
import com.gunshi.project.hsz.model.ForecastTask; import com.gunshi.project.hsz.model.ForecastTask;
import com.gunshi.project.hsz.model.RegressionEquation; import com.gunshi.project.hsz.model.RegressionEquation;
import com.gunshi.project.hsz.model.SyRegressionData; import com.gunshi.project.hsz.model.SyRegressionData;
@ -40,8 +41,8 @@ public class ProjectSafeAnalyseController {
@Operation(summary = "多项式回归") @Operation(summary = "多项式回归")
@PostMapping("/caculate") @PostMapping("/caculate")
public R<Map<String, RegressionEquation>> calculate(@RequestBody ProjectSafeCalculateDto dto){ public R<ProjectSafeCalculateVo> calculate(@RequestBody ProjectSafeCalculateDto dto){
Map<String, RegressionEquation> ans = jcskSyRService.calculate(dto); ProjectSafeCalculateVo ans = jcskSyRService.calculate(dto);
return R.ok(ans); return R.ok(ans);
} }

View File

@ -0,0 +1,21 @@
package com.gunshi.project.hsz.entity.vo;
import com.gunshi.project.hsz.common.model.vo.OsmoticPressDetailVo;
import lombok.Data;
import java.util.List;
@Data
public class ProjectSafeCalculateVo {
private String one;
private String two;
private String three;
private String four;
private List<OsmoticPressDetailVo> datas;
}

View File

@ -26,6 +26,7 @@ import com.gunshi.project.hsz.util.*;
import com.ruoyi.common.utils.StringUtils; import com.ruoyi.common.utils.StringUtils;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -790,9 +791,9 @@ public class JcskSyRService extends ServiceImpl<JcskSyRMapper, JcskSyR> {
* @param dto * @param dto
* @return * @return
*/ */
public Map<String, RegressionEquation> calculate(ProjectSafeCalculateDto dto) { public ProjectSafeCalculateVo calculate(@NotNull ProjectSafeCalculateDto dto) {
LambdaQueryWrapper<JcskSyB> queryWrapper = new LambdaQueryWrapper<>(); LambdaQueryWrapper<JcskSyB> queryWrapper = new LambdaQueryWrapper<>();
Map<String, RegressionEquation> res = new HashMap<>(); ProjectSafeCalculateVo res = new ProjectSafeCalculateVo();
//根据dvcd查询stcd和mpcd //根据dvcd查询stcd和mpcd
queryWrapper.eq(JcskSyB::getDvcd, dto.getDvcd()); queryWrapper.eq(JcskSyB::getDvcd, dto.getDvcd());
JcskSyB jcskSyB = jcskSyBService.getBaseMapper().selectOne(queryWrapper); JcskSyB jcskSyB = jcskSyBService.getBaseMapper().selectOne(queryWrapper);
@ -835,13 +836,22 @@ public class JcskSyRService extends ServiceImpl<JcskSyRMapper, JcskSyR> {
return res; return res;
} }
RegressionEquation first = RegressionAnalysis.calculateLinear(data); RegressionEquation first = RegressionAnalysis.calculateLinear(data);
if(first != null){
res.setOne(first.toString());
}
RegressionEquation second = RegressionAnalysis.calculateQuadratic(data); RegressionEquation second = RegressionAnalysis.calculateQuadratic(data);
if(second != null){
res.setTwo(second.toString());
}
RegressionEquation three = RegressionAnalysis.calculateCubic(data); RegressionEquation three = RegressionAnalysis.calculateCubic(data);
if(three != null){
res.setThree(three.toString());
}
RegressionEquation four = RegressionAnalysis.calculateQuartic(data); RegressionEquation four = RegressionAnalysis.calculateQuartic(data);
res.put("first", first); if(four != null){
res.put("second", second); res.setFour(four.toString());
res.put("three", three); }
res.put("four", four); res.setDatas(data);
return res; return res;
} }

View File

@ -163,179 +163,166 @@ public class RegressionAnalysis {
return new RegressionEquation(2, coefficients, rSquared, n); return new RegressionEquation(2, coefficients, rSquared, n);
} }
/** /**
* 3- * 3-
*/ */
public static RegressionEquation calculateCubic(List<OsmoticPressDetailVo> data) { public static RegressionEquation calculateCubic(List<OsmoticPressDetailVo> data) {
int n = 0;
List<BigDecimal> coefficientsList = null;
BigDecimal rSquared = null;
try { try {
List<OsmoticPressDetailVo> validData = filterValidData(data); List<OsmoticPressDetailVo> validData = filterValidData(data);
if (validData.size() < 4) { if (validData.size() < 4) {
throw new IllegalArgumentException("三次回归至少需要4个数据点"); throw new IllegalArgumentException("三次回归至少需要4个数据点");
} }
n = validData.size(); int n = validData.size();
// 使用中心化数据提高数值稳定性 // 使用高精度计算各项和
BigDecimal meanX = calculateMean(validData, OsmoticPressDetailVo::getRz); MathContext mc = new MathContext(200, RoundingMode.HALF_UP); // 提高精度到200位
// 计算各项和(使用中心化数据)
BigDecimal sumX = BigDecimal.ZERO, sumY = BigDecimal.ZERO; BigDecimal sumX = BigDecimal.ZERO, sumY = BigDecimal.ZERO;
BigDecimal sumX2 = BigDecimal.ZERO, sumX3 = BigDecimal.ZERO, sumX4 = BigDecimal.ZERO; BigDecimal sumX2 = BigDecimal.ZERO, sumX3 = BigDecimal.ZERO, sumX4 = BigDecimal.ZERO;
BigDecimal sumX5 = BigDecimal.ZERO, sumX6 = BigDecimal.ZERO; BigDecimal sumX5 = BigDecimal.ZERO, sumX6 = BigDecimal.ZERO;
BigDecimal sumXY = BigDecimal.ZERO, sumX2Y = BigDecimal.ZERO, sumX3Y = BigDecimal.ZERO; BigDecimal sumXY = BigDecimal.ZERO, sumX2Y = BigDecimal.ZERO, sumX3Y = BigDecimal.ZERO;
for (OsmoticPressDetailVo vo : validData) { for (OsmoticPressDetailVo vo : validData) {
BigDecimal x = vo.getRz().subtract(meanX); // 中心化 BigDecimal x = vo.getRz();
BigDecimal y = vo.getValue(); BigDecimal y = vo.getValue();
BigDecimal x2 = x.multiply(x); BigDecimal x2 = x.multiply(x, mc);
BigDecimal x3 = x2.multiply(x); BigDecimal x3 = x2.multiply(x, mc);
BigDecimal x4 = x3.multiply(x); BigDecimal x4 = x3.multiply(x, mc);
BigDecimal x5 = x4.multiply(x); BigDecimal x5 = x4.multiply(x, mc);
BigDecimal x6 = x5.multiply(x); BigDecimal x6 = x5.multiply(x, mc);
sumX = sumX.add(x); sumX = sumX.add(x, mc);
sumY = sumY.add(y); sumY = sumY.add(y, mc);
sumX2 = sumX2.add(x2); sumX2 = sumX2.add(x2, mc);
sumX3 = sumX3.add(x3); sumX3 = sumX3.add(x3, mc);
sumX4 = sumX4.add(x4); sumX4 = sumX4.add(x4, mc);
sumX5 = sumX5.add(x5); sumX5 = sumX5.add(x5, mc);
sumX6 = sumX6.add(x6); sumX6 = sumX6.add(x6, mc);
sumXY = sumXY.add(x.multiply(y)); sumXY = sumXY.add(x.multiply(y, mc), mc);
sumX2Y = sumX2Y.add(x2.multiply(y)); sumX2Y = sumX2Y.add(x2.multiply(y, mc), mc);
sumX3Y = sumX3Y.add(x3.multiply(y)); sumX3Y = sumX3Y.add(x3.multiply(y, mc), mc);
} }
BigDecimal nBig = new BigDecimal(n); BigDecimal nBig = new BigDecimal(n);
// 构建正规方程组(中心化后的矩阵条件数更好) // 构建正规方程组 - 使用中心化数据提高数值稳定性
BigDecimal meanX = sumX.divide(nBig, mc);
// 重新计算中心化后的各阶矩
BigDecimal m2 = BigDecimal.ZERO, m3 = BigDecimal.ZERO, m4 = BigDecimal.ZERO, m5 = BigDecimal.ZERO, m6 = BigDecimal.ZERO;
BigDecimal m1y = BigDecimal.ZERO, m2y = BigDecimal.ZERO, m3y = BigDecimal.ZERO;
for (OsmoticPressDetailVo vo : validData) {
BigDecimal xCentered = vo.getRz().subtract(meanX, mc);
BigDecimal y = vo.getValue();
BigDecimal x2 = xCentered.multiply(xCentered, mc);
BigDecimal x3 = x2.multiply(xCentered, mc);
BigDecimal x4 = x3.multiply(xCentered, mc);
BigDecimal x5 = x4.multiply(xCentered, mc);
BigDecimal x6 = x5.multiply(xCentered, mc);
m2 = m2.add(x2, mc);
m3 = m3.add(x3, mc);
m4 = m4.add(x4, mc);
m5 = m5.add(x5, mc);
m6 = m6.add(x6, mc);
m1y = m1y.add(xCentered.multiply(y, mc), mc);
m2y = m2y.add(x2.multiply(y, mc), mc);
m3y = m3y.add(x3.multiply(y, mc), mc);
}
// 中心化后的正规方程组常数项为0
BigDecimal[][] matrix = { BigDecimal[][] matrix = {
{nBig, sumX, sumX2, sumX3}, {nBig, BigDecimal.ZERO, m2, m3},
{sumX, sumX2, sumX3, sumX4}, {BigDecimal.ZERO, m2, m3, m4},
{sumX2, sumX3, sumX4, sumX5}, {m2, m3, m4, m5},
{sumX3, sumX4, sumX5, sumX6} {m3, m4, m5, m6}
}; };
BigDecimal[] vector = {sumY, sumXY, sumX2Y, sumX3Y}; // 注意第一个方程对应常数项右侧是sumY
BigDecimal[] vector = {
// 使用改进的高斯消元法 sumY, // 对应常数项
BigDecimal[] centeredCoefficients = solveLinearSystemImproved(matrix, vector); m1y, // 对应一次项
m2y, // 对应二次项
m3y // 对应三次项
};
// 使用高精度求解
BigDecimal[] centeredCoefficients = solveLinearSystemHighPrecision(matrix, vector, mc);
if (centeredCoefficients == null) { if (centeredCoefficients == null) {
throw new IllegalArgumentException("无法求解三次回归方程"); throw new IllegalArgumentException("无法求解三次回归方程");
} }
// 将中心化系数转换回原始坐标 // 将中心化系数转换回原始坐标
BigDecimal[] originalCoefficients = convertToOriginalCoefficients(centeredCoefficients, meanX); BigDecimal[] originalCoefficients = convertCenteredToOriginalCubic(centeredCoefficients, meanX, mc);
// 计算R² // 计算R²
BigDecimal meanY = sumY.divide(nBig, MathContext.DECIMAL128); BigDecimal meanY = sumY.divide(nBig, mc);
coefficientsList = Arrays.asList(originalCoefficients); List<BigDecimal> coefficientsList = Arrays.asList(originalCoefficients);
rSquared = calculateRSquared(validData, coefficientsList, meanY); BigDecimal rSquared = calculateRSquared(validData, coefficientsList, meanY);
return new RegressionEquation(3, coefficientsList, rSquared, n);
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
e.printStackTrace(); e.printStackTrace();
return null; return null;
} }
return new RegressionEquation(3, coefficientsList, rSquared, n);
} }
/** /**
* * 线
*/ */
private static BigDecimal[] convertToOriginalCoefficients(BigDecimal[] centeredCoeffs, BigDecimal meanX) { private static BigDecimal[] solveLinearSystemHighPrecision(BigDecimal[][] matrix, BigDecimal[] vector, MathContext mc) {
// 对于三次多项式y = a + b(x - μ) + c(x - μ)² + d(x - μ)³
// 展开后y = (a - bμ + cμ² - dμ³) + (b - 2cμ + 3dμ²)x + (c - 3dμ)x² + d x³
BigDecimal a = centeredCoeffs[0];
BigDecimal b = centeredCoeffs[1];
BigDecimal c = centeredCoeffs[2];
BigDecimal d = centeredCoeffs[3];
BigDecimal mu = meanX;
BigDecimal mu2 = mu.multiply(mu);
BigDecimal mu3 = mu2.multiply(mu);
BigDecimal newA = a.subtract(b.multiply(mu))
.add(c.multiply(mu2))
.subtract(d.multiply(mu3));
BigDecimal newB = b.subtract(BigDecimal.valueOf(2).multiply(c).multiply(mu))
.add(BigDecimal.valueOf(3).multiply(d).multiply(mu2));
BigDecimal newC = c.subtract(BigDecimal.valueOf(3).multiply(d).multiply(mu));
BigDecimal newD = d;
return new BigDecimal[]{newA, newB, newC, newD};
}
/**
*
*/
private static BigDecimal[] solveLinearSystemWithPivot(BigDecimal[][] matrix, BigDecimal[] vector) {
int n = vector.length; int n = vector.length;
// 复制矩阵和向量
BigDecimal[][] a = new BigDecimal[n][n]; BigDecimal[][] a = new BigDecimal[n][n];
BigDecimal[] b = new BigDecimal[n]; BigDecimal[] b = new BigDecimal[n];
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
System.arraycopy(matrix[i], 0, a[i], 0, n); a[i] = matrix[i].clone();
b[i] = vector[i]; b[i] = vector[i];
} }
int[] rowPerm = new int[n]; // 部分主元高斯消元
int[] colPerm = new int[n];
for (int i = 0; i < n; i++) {
rowPerm[i] = i;
colPerm[i] = i;
}
// 完全主元高斯消元
for (int k = 0; k < n; k++) { for (int k = 0; k < n; k++) {
// 寻找主元 // 寻找主元
int maxRow = k, maxCol = k; int maxRow = k;
BigDecimal maxVal = a[rowPerm[k]][colPerm[k]].abs(); BigDecimal maxVal = a[k][k].abs();
for (int i = k + 1; i < n; i++) {
for (int i = k; i < n; i++) { BigDecimal current = a[i][k].abs();
for (int j = k; j < n; j++) { if (current.compareTo(maxVal) > 0) {
BigDecimal current = a[rowPerm[i]][colPerm[j]].abs(); maxVal = current;
if (current.compareTo(maxVal) > 0) { maxRow = i;
maxVal = current;
maxRow = i;
maxCol = j;
}
} }
} }
// 交换行 // 交换行
if (maxRow != k) { if (maxRow != k) {
int temp = rowPerm[k]; BigDecimal[] tempRow = a[k];
rowPerm[k] = rowPerm[maxRow]; a[k] = a[maxRow];
rowPerm[maxRow] = temp; a[maxRow] = tempRow;
BigDecimal tempB = b[k];
b[k] = b[maxRow];
b[maxRow] = tempB;
} }
// 交换列 // 检查主元是否为0 - 增加容错
if (maxCol != k) { if (a[k][k].abs().compareTo(new BigDecimal("1E-100")) < 0) {
int temp = colPerm[k]; // 尝试使用伪逆或特殊处理
colPerm[k] = colPerm[maxCol]; return solveSingularSystem(a, b, mc);
colPerm[maxCol] = temp;
}
// 主元为0矩阵奇异
if (a[rowPerm[k]][colPerm[k]].compareTo(BigDecimal.ZERO) == 0) {
return null;
} }
// 消元 // 消元
for (int i = k + 1; i < n; i++) { for (int i = k + 1; i < n; i++) {
BigDecimal factor = a[rowPerm[i]][colPerm[k]].divide(a[rowPerm[k]][colPerm[k]], 100, RoundingMode.HALF_UP); BigDecimal factor = a[i][k].divide(a[k][k], mc);
for (int j = k; j < n; j++) { for (int j = k; j < n; j++) {
a[rowPerm[i]][colPerm[j]] = a[rowPerm[i]][colPerm[j]].subtract( a[i][j] = a[i][j].subtract(factor.multiply(a[k][j], mc), mc);
factor.multiply(a[rowPerm[k]][colPerm[j]]));
} }
b[rowPerm[i]] = b[rowPerm[i]].subtract(factor.multiply(b[rowPerm[k]])); b[i] = b[i].subtract(factor.multiply(b[k], mc), mc);
} }
} }
@ -344,18 +331,75 @@ public class RegressionAnalysis {
for (int i = n - 1; i >= 0; i--) { for (int i = n - 1; i >= 0; i--) {
BigDecimal sum = BigDecimal.ZERO; BigDecimal sum = BigDecimal.ZERO;
for (int j = i + 1; j < n; j++) { for (int j = i + 1; j < n; j++) {
sum = sum.add(a[rowPerm[i]][colPerm[j]].multiply(x[colPerm[j]])); sum = sum.add(a[i][j].multiply(x[j], mc), mc);
} }
x[colPerm[i]] = b[rowPerm[i]].subtract(sum).divide(a[rowPerm[i]][colPerm[i]], 100, RoundingMode.HALF_UP); x[i] = b[i].subtract(sum, mc).divide(a[i][i], mc);
} }
// 恢复原始顺序 return x;
BigDecimal[] result = new BigDecimal[n]; }
for (int i = 0; i < n; i++) {
result[i] = x[i]; /**
*
*/
private static BigDecimal[] solveSingularSystem(BigDecimal[][] a, BigDecimal[] b, MathContext mc) {
int n = b.length;
// 简单处理将接近0的主元设为一个小值
for (int k = 0; k < n; k++) {
if (a[k][k].abs().compareTo(new BigDecimal("1E-100")) < 0) {
a[k][k] = new BigDecimal("1E-50"); // 设置一个小值
}
for (int i = k + 1; i < n; i++) {
BigDecimal factor = a[i][k].divide(a[k][k], mc);
for (int j = k; j < n; j++) {
a[i][j] = a[i][j].subtract(factor.multiply(a[k][j], mc), mc);
}
b[i] = b[i].subtract(factor.multiply(b[k], mc), mc);
}
} }
return result; // 回代
BigDecimal[] x = new BigDecimal[n];
for (int i = n - 1; i >= 0; i--) {
BigDecimal sum = BigDecimal.ZERO;
for (int j = i + 1; j < n; j++) {
sum = sum.add(a[i][j].multiply(x[j], mc), mc);
}
x[i] = b[i].subtract(sum, mc).divide(a[i][i], mc);
}
return x;
}
/**
*
*/
private static BigDecimal[] convertCenteredToOriginalCubic(BigDecimal[] centeredCoeffs, BigDecimal mean, MathContext mc) {
// centered: y = a + b(x-μ) + c(x-μ)² + d(x-μ)³
// original: y = A + Bx + Cx² + Dx³
BigDecimal a = centeredCoeffs[0];
BigDecimal b = centeredCoeffs[1];
BigDecimal c = centeredCoeffs[2];
BigDecimal d = centeredCoeffs[3];
BigDecimal mu = mean;
BigDecimal mu2 = mu.multiply(mu, mc);
BigDecimal mu3 = mu2.multiply(mu, mc);
BigDecimal A = a.subtract(b.multiply(mu, mc), mc)
.add(c.multiply(mu2, mc), mc)
.subtract(d.multiply(mu3, mc), mc);
BigDecimal B = b.subtract(BigDecimal.valueOf(2).multiply(c, mc).multiply(mu, mc), mc)
.add(BigDecimal.valueOf(3).multiply(d, mc).multiply(mu2, mc), mc);
BigDecimal C = c.subtract(BigDecimal.valueOf(3).multiply(d, mc).multiply(mu, mc), mc);
BigDecimal D = d;
return new BigDecimal[]{A, B, C, D};
} }