IndisputableMonolith.Cost.Ndim.XCoordinates
IndisputableMonolith/Cost/Ndim/XCoordinates.lean · 154 lines · 15 declarations
show as:
view math explainer →
1import IndisputableMonolith.Cost.Ndim.Core
2
/-!
# Positive-coordinate Hessian formulas

This module records the `x`-coordinate Hessian formulas for the
multi-component reciprocal cost.

The general entry formula is written in terms of the positive aggregate
`R = aggregate α x`. We then specialize to the `2 × 2` case to obtain a
closed determinant factorization, the zero-cost degeneracy statement,
and a generic nondegeneracy criterion away from the neutral locus.
-/
14
15namespace IndisputableMonolith
16namespace Cost
17namespace Ndim
18
19open scoped BigOperators
20open Matrix
21
/-- Componentwise quotient of the weights by the coordinates: the `k`-th
component is `α k / x k`.  This is the vector whose outer product with itself
appears in the first term of `xHessianEntry`. -/
noncomputable def xDirection {n : ℕ} (α x : Vec n) : Vec n :=
  fun k => α k / x k
25
/-- Diagonal-only term of the `x`-coordinate Hessian: it equals
`α i / (x i) ^ 2` when the two indices coincide and `0` otherwise. -/
noncomputable def xDiagonalCorrection {n : ℕ} (α x : Vec n) (i j : Fin n) : ℝ :=
  if i = j then
    α i / (x i) ^ 2
  else
    0
29
/-- Entry `(i, j)` of the `x`-coordinate Hessian of `JcostN`.

Writing `R = aggregate α x`, `d = xDirection α x` and
`c = xDiagonalCorrection α x`, the entry is
`((R + R⁻¹) / 2) * d i * d j - ((R - R⁻¹) / 2) * c i j`. -/
noncomputable def xHessianEntry {n : ℕ} (α x : Vec n) (i j : Fin n) : ℝ :=
  ((aggregate α x + (aggregate α x)⁻¹) / 2) * xDirection α x i * xDirection α x j
    - ((aggregate α x - (aggregate α x)⁻¹) / 2) * xDiagonalCorrection α x i j
34
/-- The `x`-coordinate Hessian as a two-index function: the value at
`(r, c)` is `xHessianEntry α x r c`. -/
noncomputable def xHessianMatrix {n : ℕ} (α x : Vec n) : Fin n → Fin n → ℝ :=
  fun r c => xHessianEntry α x r c
38
/-- Off the diagonal the correction term vanishes, so the Hessian entry is the
coefficient `(R + R⁻¹) / 2` (with `R = aggregate α x`) times the product of the
two direction components. -/
theorem xHessianEntry_offDiag {n : ℕ} (α x : Vec n) {i j : Fin n} (hij : i ≠ j) :
    xHessianEntry α x i j
      = ((aggregate α x + (aggregate α x)⁻¹) / 2) * xDirection α x i * xDirection α x j := by
  -- For distinct indices the `if i = j` test selects the `else 0` branch.
  have hcorr : xDiagonalCorrection α x i j = 0 := by
    unfold xDiagonalCorrection
    exact if_neg hij
  unfold xHessianEntry
  -- Drop the vanishing correction term; what remains is the stated right-hand side.
  rw [hcorr, mul_zero, sub_zero]
44
/-- Closed form of the diagonal Hessian entry: with `R = aggregate α x`,
`xHessianEntry α x i i = (α i / (2 * (x i) ^ 2)) * ((α i - 1) * R + (α i + 1) * R⁻¹)`. -/
theorem xHessianEntry_diag {n : ℕ} (α x : Vec n) (i : Fin n) :
    xHessianEntry α x i i
      = (α i / (2 * (x i) ^ 2))
        * (((α i - 1) * aggregate α x) + ((α i + 1) * (aggregate α x)⁻¹)) := by
  -- On the diagonal the `if i = i` test selects the `then` branch.
  have hcorr : xDiagonalCorrection α x i i = α i / (x i) ^ 2 := by
    unfold xDiagonalCorrection
    exact if_pos rfl
  unfold xHessianEntry xDirection
  rw [hcorr]
  -- Both sides are now the same expression up to field rearrangement.
  ring
52
/-- When `aggregate α x = 1` the coefficient of the outer-product term is `1`
and the coefficient of the diagonal correction is `0`, so every Hessian entry
reduces to `xDirection α x i * xDirection α x j`. -/
theorem xHessianEntry_zero_cost {n : ℕ} (α x : Vec n) {i j : Fin n}
    (hR : aggregate α x = 1) :
    xHessianEntry α x i j = xDirection α x i * xDirection α x j := by
  -- Coefficient of the outer-product term at `R = 1`.
  have hsum : (aggregate α x + (aggregate α x)⁻¹) / 2 = 1 := by
    rw [hR]
    norm_num
  -- Coefficient of the diagonal correction at `R = 1`.
  have hdiff : (aggregate α x - (aggregate α x)⁻¹) / 2 = 0 := by
    rw [hR]
    norm_num
  unfold xHessianEntry
  rw [hsum, hdiff, one_mul, zero_mul, sub_zero]
63
/-- Builds the two-component vector `![u, v]`: `u` is the entry at index `0`
and `v` the entry at index `1`. -/
abbrev vec2 (u v : ℝ) : Vec 2 :=
  ![u, v]
66
/-- The `2 × 2` positive-coordinate Hessian in which the aggregate is a free
parameter `R` rather than being computed from `a b x y`.

The diagonal entries have the shape `(w / (2 * t ^ 2)) * ((w - 1) * R + (w + 1) * R⁻¹)`
with `(w, t) = (a, x)` and `(b, y)` respectively; both off-diagonal entries are
`((a * b) / (2 * x * y)) * (R + R⁻¹)`. -/
noncomputable def xHessianMatrix2OfR (a b x y R : ℝ) : Matrix (Fin 2) (Fin 2) ℝ :=
  !![(a / (2 * x ^ 2)) * (((a - 1) * R) + ((a + 1) * R⁻¹)),
       ((a * b) / (2 * x * y)) * (R + R⁻¹);
     ((a * b) / (2 * x * y)) * (R + R⁻¹),
       (b / (2 * y ^ 2)) * (((b - 1) * R) + ((b + 1) * R⁻¹))]
76
/-- The `2 × 2` `x`-coordinate Hessian proper: `xHessianMatrix2OfR` with its
parameter `R` set to `aggregate (vec2 a b) (vec2 x y)`. -/
noncomputable def xHessianMatrix2 (a b x y : ℝ) : Matrix (Fin 2) (Fin 2) ℝ :=
  xHessianMatrix2OfR a b x y
    (aggregate (vec2 a b) (vec2 x y))
81
/-- The explicit `2 × 2` matrix `xHessianMatrix2 a b x y` agrees entry by entry
with the general formula `xHessianEntry` evaluated at the weights `vec2 a b`
and the coordinates `vec2 x y`. -/
theorem xHessianMatrix2_eq_general (a b x y : ℝ) :
    xHessianMatrix2 a b x y
      = fun i j => xHessianEntry (vec2 a b) (vec2 x y) i j := by
  -- Compare the two sides one entry at a time.
  ext r c
  -- Enumerate the four index pairs and unfold every definition involved.
  fin_cases r <;> fin_cases c <;>
    simp [xHessianMatrix2, xHessianMatrix2OfR, vec2, xHessianEntry,
      xDiagonalCorrection, xDirection]
  -- Whatever `simp` leaves open is closed by `ring`.
  all_goals ring
90
/-- Factored determinant of `xHessianMatrix2OfR` for nonzero `x`, `y` and `R`:
the numerator is `-(a * b * (R - 1) * (R + 1) * (R ^ 2 * a + R ^ 2 * b - R ^ 2 + a + b + 1))`
and the denominator is `4 * R ^ 2 * x ^ 2 * y ^ 2`. -/
theorem det_xHessianMatrix2OfR_formula (a b x y R : ℝ)
    (hx : x ≠ 0) (hy : y ≠ 0) (hR : R ≠ 0) :
    Matrix.det (xHessianMatrix2OfR a b x y R)
      = -(a * b * (R - 1) * (R + 1) * (R ^ 2 * a + R ^ 2 * b - R ^ 2 + a + b + 1))
        / (4 * R ^ 2 * x ^ 2 * y ^ 2) := by
  -- Expand the `2 × 2` determinant in terms of the four explicit entries.
  simp [Matrix.det_fin_two, xHessianMatrix2OfR]
  -- Clear the denominators using the three nonvanishing hypotheses.
  field_simp [hR, hx, hy]
  ring
99
/-- Factored determinant of `xHessianMatrix2`, obtained from
`det_xHessianMatrix2OfR_formula` by taking `R` to be
`aggregate (vec2 a b) (vec2 x y)`; the hypothesis `R ≠ 0` is supplied by
`aggregate_pos`. -/
theorem det_xHessianMatrix2_formula (a b x y : ℝ)
    (hx : x ≠ 0) (hy : y ≠ 0) :
    let R := aggregate (vec2 a b) (vec2 x y)
    Matrix.det (xHessianMatrix2 a b x y)
      = -(a * b * (R - 1) * (R + 1) * (R ^ 2 * a + R ^ 2 * b - R ^ 2 + a + b + 1))
        / (4 * R ^ 2 * x ^ 2 * y ^ 2) := by
  -- Bring the `let`-bound aggregate into the context as a local definition.
  intro R
  -- Expose the underlying `xHessianMatrix2OfR` so the parametric formula applies.
  unfold xHessianMatrix2
  exact det_xHessianMatrix2OfR_formula a b x y R hx hy
    (aggregate_pos (vec2 a b) (vec2 x y)).ne'
109
/-- If the aggregate equals `1` (the neutral locus), the `2 × 2` `x`-coordinate
Hessian has determinant zero. -/
theorem det_xHessianMatrix2_zero_cost (a b x y : ℝ)
    (hx : x ≠ 0) (hy : y ≠ 0)
    (hR : aggregate (vec2 a b) (vec2 x y) = 1) :
    Matrix.det (xHessianMatrix2 a b x y) = 0 := by
  -- Replace the determinant by its factored closed form.
  rw [det_xHessianMatrix2_formula a b x y hx hy]
  -- Substituting the aggregate by `1` kills the factor `R - 1` in the numerator.
  simp [hR]
118
/-- Nondegeneracy criterion for the `2 × 2` `x`-coordinate Hessian.

Write `R = aggregate (vec2 a b) (vec2 x y)`.  If `x`, `y`, `a`, `b` are all
nonzero, `R ≠ 1`, and the factor `R ^ 2 * a + R ^ 2 * b - R ^ 2 + a + b + 1`
is nonzero, then the determinant of `xHessianMatrix2 a b x y` is nonzero. -/
theorem det_xHessianMatrix2_ne_zero_of_generic (a b x y : ℝ)
    (hx : x ≠ 0) (hy : y ≠ 0)
    (ha : a ≠ 0) (hb : b ≠ 0)
    (hR1 : aggregate (vec2 a b) (vec2 x y) ≠ 1)
    (hdisc :
      (aggregate (vec2 a b) (vec2 x y)) ^ 2 * a
        + (aggregate (vec2 a b) (vec2 x y)) ^ 2 * b
        - (aggregate (vec2 a b) (vec2 x y)) ^ 2
        + a + b + 1 ≠ 0) :
    Matrix.det (xHessianMatrix2 a b x y) ≠ 0 := by
  let R := aggregate (vec2 a b) (vec2 x y)
  -- Positivity of the aggregate gives both `R ≠ 0` and `R + 1 ≠ 0`.
  have hRpos : 0 < R := aggregate_pos (vec2 a b) (vec2 x y)
  have hR : R ≠ 0 := hRpos.ne'
  have hRp1 : R + 1 ≠ 0 := (add_pos hRpos zero_lt_one).ne'
  -- The denominator of the closed form is a product of nonzero factors.
  have hden : 4 * R ^ 2 * x ^ 2 * y ^ 2 ≠ 0 :=
    mul_ne_zero
      (mul_ne_zero
        (mul_ne_zero (by norm_num) (pow_ne_zero 2 hR))
        (pow_ne_zero 2 hx))
      (pow_ne_zero 2 hy)
  -- So is the (un-negated) numerator: each of its five factors is nonzero.
  have hnum :
      a * b * (R - 1) * (R + 1) * (R ^ 2 * a + R ^ 2 * b - R ^ 2 + a + b + 1) ≠ 0 :=
    mul_ne_zero
      (mul_ne_zero
        (mul_ne_zero (mul_ne_zero ha hb) (sub_ne_zero.mpr hR1))
        hRp1)
      hdisc
  rw [det_xHessianMatrix2_formula a b x y hx hy]
  exact div_ne_zero (neg_ne_zero.mpr hnum) hden
150
151end Ndim
152end Cost
153end IndisputableMonolith
154