1use crate::array::*;
2use crate::category::*;
3use crate::finite_function::*;
4use crate::indexed_coproduct::*;
5use crate::operations::*;
6use crate::semifinite::*;
7
8use crate::strict::open_hypergraph::*;
9
10use super::traits::*;
11
12use core::fmt::Debug;
13use num_traits::One;
14
15type ResidualFn<K, O1, A1, O2> =
16 dyn Fn(&Operations<K, O1, A1>) -> IndexedCoproduct<K, SemifiniteFunction<K, O2>>;
17
18pub struct Optic<
20 F: Functor<K, O1, A1, O2, A2>,
21 R: Functor<K, O1, A1, O2, A2>,
22 K: ArrayKind,
23 O1,
24 A1,
25 O2,
26 A2,
27> {
28 pub fwd: F,
29 pub rev: R,
30 pub residual: Box<ResidualFn<K, O1, A1, O2>>,
31 _phantom: std::marker::PhantomData<A2>,
32}
33
34impl<
35 F: Functor<K, O1, A1, O2, A2>,
36 R: Functor<K, O1, A1, O2, A2>,
37 K: ArrayKind,
38 O1,
39 A1,
40 O2,
41 A2,
42 > Optic<F, R, K, O1, A1, O2, A2>
43{
44 pub fn new(
45 fwd: F,
46 rev: R,
47 residual: Box<ResidualFn<K, O1, A1, O2>>,
48 ) -> Optic<F, R, K, O1, A1, O2, A2> {
49 let _phantom = std::marker::PhantomData;
50 Optic {
51 fwd,
52 rev,
53 residual,
54 _phantom,
55 }
56 }
57}
58
59impl<F, R, K: ArrayKind + Debug, O1, A1, O2, A2> Functor<K, O1, A1, O2, A2>
60 for Optic<F, R, K, O1, A1, O2, A2>
61where
62 F: Functor<K, O1, A1, O2, A2>,
63 R: Functor<K, O1, A1, O2, A2>,
64 K::Type<K::I>: NaturalArray<K>,
65 K::Type<O1>: Array<K, O1> + PartialEq,
66 K::Type<A1>: Array<K, A1>,
67 K::Type<O2>: Array<K, O2> + PartialEq + Debug,
68 K::Type<A2>: Array<K, A2>,
69{
70 fn map_object(
71 &self,
72 a: &SemifiniteFunction<K, O1>,
73 ) -> IndexedCoproduct<K, SemifiniteFunction<K, O2>> {
74 let fa = self.fwd.map_object(a);
76 let ra = self.rev.map_object(a);
77
78 assert_eq!(fa.len(), ra.len());
79 let n = fa.len();
80
81 let paired = fa
83 .coproduct(&ra)
84 .expect("Coproduct of SemifiniteFunction always succeeds");
85 let p = FiniteFunction::transpose(K::I::one() + K::I::one(), n);
86
87 let sources = FiniteFunction::new(
88 fa.sources.table + ra.sources.table,
89 fa.sources.target + ra.sources.target - K::I::one(),
90 )
91 .unwrap();
92 let values = paired.indexed_values(&p).unwrap();
93 IndexedCoproduct::new(sources, values).unwrap()
94 }
95
96 fn map_operations(&self, ops: Operations<K, O1, A1>) -> OpenHypergraph<K, O2, A2> {
97 let fwd = self.fwd.map_operations(ops.clone());
99 let rev = self.rev.map_operations(ops.clone());
100
101 let fa = self.fwd.map_object(&ops.a.values);
103 let fb = self.fwd.map_object(&ops.b.values);
104 let ra = self.rev.map_object(&ops.a.values);
105 let rb = self.rev.map_object(&ops.b.values);
106
107 let m = (self.residual)(&ops);
108
109 let fwd_interleave = interleave_blocks(&ops.b.flatmap_sources(&fb), &m).dagger();
111 let rev_cointerleave = interleave_blocks(&m, &ops.b.flatmap_sources(&rb));
112
113 debug_assert_eq!(fwd.target(), fwd_interleave.source());
114 debug_assert_eq!(rev_cointerleave.target(), rev.source());
115
116 let i_fb = OpenHypergraph::identity(fb.values.clone());
117 let i_rb = OpenHypergraph::identity(rb.values.clone());
118
119 let lhs = fwd.compose(&fwd_interleave).unwrap().tensor(&i_rb);
121 let rhs = i_fb.tensor(&rev_cointerleave.compose(&rev).unwrap());
122 let c = lhs.compose(&rhs).unwrap();
123
124 let d = partial_dagger(&c, &fa, &fb, &ra, &rb);
126
127 let lhs = interleave_blocks(&fa, &ra).dagger();
129 let rhs = interleave_blocks(&fb, &rb);
130
131 lhs.compose(&d).unwrap().compose(&rhs).unwrap()
132 }
133
134 fn map_arrow(&self, f: &OpenHypergraph<K, O1, A1>) -> OpenHypergraph<K, O2, A2> {
135 define_map_arrow(self, f)
136 }
137}
138
139impl<F, R, K: ArrayKind + Debug, O1, A1, O2, A2> Optic<F, R, K, O1, A1, O2, A2>
140where
141 F: Functor<K, O1, A1, O2, A2>,
142 R: Functor<K, O1, A1, O2, A2>,
143 K::Type<K::I>: NaturalArray<K>,
144 K::Type<O1>: Array<K, O1> + PartialEq,
145 K::Type<A1>: Array<K, A1>,
146 K::Type<O2>: Array<K, O2> + PartialEq + Debug,
147 K::Type<A2>: Array<K, A2>,
148{
149 pub fn adapt(
150 &self,
151 c: &OpenHypergraph<K, O2, A2>,
152 a: &SemifiniteFunction<K, O1>,
153 b: &SemifiniteFunction<K, O1>,
154 ) -> OpenHypergraph<K, O2, A2> {
155 let fa = self.fwd.map_object(a);
156 let fb = self.fwd.map_object(b);
157 let ra = self.rev.map_object(a);
158 let rb = self.rev.map_object(b);
159
160 let lhs = interleave_blocks(&fa, &ra);
162 let rhs = interleave_blocks(&fb, &rb).dagger();
163 let d = lhs.compose(c).unwrap().compose(&rhs).unwrap();
164
165 debug_assert_eq!(d.source(), fa.coproduct(&ra).unwrap().values);
168 debug_assert_eq!(d.target(), fb.coproduct(&rb).unwrap().values);
169
170 partial_dagger(&d, &fa, &fb, &rb, &ra)
172 }
173}
174
175fn interleave_blocks<K: ArrayKind, O, A>(
176 a: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
177 b: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
178) -> OpenHypergraph<K, O, A>
179where
180 K::Type<K::I>: NaturalArray<K>,
181 K::Type<O>: Array<K, O> + PartialEq,
182 K::Type<A>: Array<K, A>,
183{
184 if a.len() != b.len() {
185 panic!("Can't interleave types of unequal lengths");
186 }
187
188 let ab = a
189 .coproduct(b)
190 .expect("Coproduct of SemifiniteFunction always succeeds");
191
192 let two = K::I::one() + K::I::one();
193 let s = FiniteFunction::identity(ab.values.len());
194 let t = ab
195 .sources
196 .injections(&FiniteFunction::transpose(two, a.len()))
197 .unwrap();
198
199 OpenHypergraph::spider(s, t, ab.values.clone()).unwrap()
200}
201
202fn partial_dagger<K: ArrayKind + Debug, O, A>(
204 c: &OpenHypergraph<K, O, A>,
205 fa: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
206 fb: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
207 ra: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
208 rb: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
209) -> OpenHypergraph<K, O, A>
210where
211 K::Type<K::I>: NaturalArray<K>,
212 K::Type<O>: Array<K, O>,
213 K::Type<A>: Array<K, A>,
214{
215 let s = {
216 let s_i = FiniteFunction::inj0(fa.values.len(), rb.values.len())
217 .compose(&c.s)
218 .unwrap();
219
220 let s_o = FiniteFunction::inj1(fb.values.len(), ra.values.len())
221 .compose(&c.t)
222 .unwrap();
223
224 s_i.coproduct(&s_o).unwrap()
225 };
226
227 let t = {
228 let t_i = FiniteFunction::inj0(fb.values.len(), ra.values.len())
229 .compose(&c.t)
230 .unwrap();
231 let t_o = FiniteFunction::inj1(fa.values.len(), rb.values.len())
232 .compose(&c.s)
233 .unwrap();
234
235 t_i.coproduct(&t_o).unwrap()
236 };
237
238 OpenHypergraph::new(s, t, c.h.clone()).unwrap()
239}