open_hypergraphs/strict/functor/optic.rs

1use crate::array::*;
2use crate::category::*;
3use crate::finite_function::*;
4use crate::indexed_coproduct::*;
5use crate::operations::*;
6use crate::semifinite::*;
7
8use crate::strict::open_hypergraph::*;
9
10use super::traits::*;
11
12use core::fmt::Debug;
13use num_traits::One;
14
15type ResidualFn<K, O1, A1, O2> =
16    dyn Fn(&Operations<K, O1, A1>) -> IndexedCoproduct<K, SemifiniteFunction<K, O2>>;
17
18/// An optic is composed of forward and reverse functors along with a residual object
19pub struct Optic<
20    F: Functor<K, O1, A1, O2, A2>,
21    R: Functor<K, O1, A1, O2, A2>,
22    K: ArrayKind,
23    O1,
24    A1,
25    O2,
26    A2,
27> {
28    pub fwd: F,
29    pub rev: R,
30    pub residual: Box<ResidualFn<K, O1, A1, O2>>,
31    _phantom: std::marker::PhantomData<A2>,
32}
33
34impl<
35        F: Functor<K, O1, A1, O2, A2>,
36        R: Functor<K, O1, A1, O2, A2>,
37        K: ArrayKind,
38        O1,
39        A1,
40        O2,
41        A2,
42    > Optic<F, R, K, O1, A1, O2, A2>
43{
44    pub fn new(
45        fwd: F,
46        rev: R,
47        residual: Box<ResidualFn<K, O1, A1, O2>>,
48    ) -> Optic<F, R, K, O1, A1, O2, A2> {
49        let _phantom = std::marker::PhantomData;
50        Optic {
51            fwd,
52            rev,
53            residual,
54            _phantom,
55        }
56    }
57}
58
59impl<F, R, K: ArrayKind + Debug, O1, A1, O2, A2> Functor<K, O1, A1, O2, A2>
60    for Optic<F, R, K, O1, A1, O2, A2>
61where
62    F: Functor<K, O1, A1, O2, A2>,
63    R: Functor<K, O1, A1, O2, A2>,
64    K::Type<K::I>: NaturalArray<K>,
65    K::Type<O1>: Array<K, O1> + PartialEq,
66    K::Type<A1>: Array<K, A1>,
67    K::Type<O2>: Array<K, O2> + PartialEq + Debug,
68    K::Type<A2>: Array<K, A2>,
69{
70    fn map_object(
71        &self,
72        a: &SemifiniteFunction<K, O1>,
73    ) -> IndexedCoproduct<K, SemifiniteFunction<K, O2>> {
74        // Each object A is mapped to F(A) ● R(A)
75        let fa = self.fwd.map_object(a);
76        let ra = self.rev.map_object(a);
77
78        assert_eq!(fa.len(), ra.len());
79        let n = fa.len();
80
81        // Create paired function similar to Python's FA + RA
82        let paired = fa
83            .coproduct(&ra)
84            .expect("Coproduct of SemifiniteFunction always succeeds");
85        let p = FiniteFunction::transpose(K::I::one() + K::I::one(), n);
86
87        let sources = FiniteFunction::new(
88            fa.sources.table + ra.sources.table,
89            fa.sources.target + ra.sources.target - K::I::one(),
90        )
91        .unwrap();
92        let values = paired.indexed_values(&p).unwrap();
93        IndexedCoproduct::new(sources, values).unwrap()
94    }
95
96    fn map_operations(&self, ops: Operations<K, O1, A1>) -> OpenHypergraph<K, O2, A2> {
97        // Forward and reverse maps
98        let fwd = self.fwd.map_operations(ops.clone());
99        let rev = self.rev.map_operations(ops.clone());
100
101        // Get mapped objects
102        let fa = self.fwd.map_object(&ops.a.values);
103        let fb = self.fwd.map_object(&ops.b.values);
104        let ra = self.rev.map_object(&ops.a.values);
105        let rb = self.rev.map_object(&ops.b.values);
106
107        let m = (self.residual)(&ops);
108
109        // Create interleavings
110        let fwd_interleave = interleave_blocks(&ops.b.flatmap_sources(&fb), &m).dagger();
111        let rev_cointerleave = interleave_blocks(&m, &ops.b.flatmap_sources(&rb));
112
113        debug_assert_eq!(fwd.target(), fwd_interleave.source());
114        debug_assert_eq!(rev_cointerleave.target(), rev.source());
115
116        let i_fb = OpenHypergraph::identity(fb.values.clone());
117        let i_rb = OpenHypergraph::identity(rb.values.clone());
118
119        // Compose the diagram parts
120        let lhs = fwd.compose(&fwd_interleave).unwrap().tensor(&i_rb);
121        let rhs = i_fb.tensor(&rev_cointerleave.compose(&rev).unwrap());
122        let c = lhs.compose(&rhs).unwrap();
123
124        // Partial dagger to bend wires
125        let d = partial_dagger(&c, &fa, &fb, &ra, &rb);
126
127        // Final interleaving
128        let lhs = interleave_blocks(&fa, &ra).dagger();
129        let rhs = interleave_blocks(&fb, &rb);
130
131        lhs.compose(&d).unwrap().compose(&rhs).unwrap()
132    }
133
134    fn map_arrow(&self, f: &OpenHypergraph<K, O1, A1>) -> OpenHypergraph<K, O2, A2> {
135        define_map_arrow(self, f)
136    }
137}
138
139impl<F, R, K: ArrayKind + Debug, O1, A1, O2, A2> Optic<F, R, K, O1, A1, O2, A2>
140where
141    F: Functor<K, O1, A1, O2, A2>,
142    R: Functor<K, O1, A1, O2, A2>,
143    K::Type<K::I>: NaturalArray<K>,
144    K::Type<O1>: Array<K, O1> + PartialEq,
145    K::Type<A1>: Array<K, A1>,
146    K::Type<O2>: Array<K, O2> + PartialEq + Debug,
147    K::Type<A2>: Array<K, A2>,
148{
149    pub fn adapt(
150        &self,
151        c: &OpenHypergraph<K, O2, A2>,
152        a: &SemifiniteFunction<K, O1>,
153        b: &SemifiniteFunction<K, O1>,
154    ) -> OpenHypergraph<K, O2, A2> {
155        let fa = self.fwd.map_object(a);
156        let fb = self.fwd.map_object(b);
157        let ra = self.rev.map_object(a);
158        let rb = self.rev.map_object(b);
159
160        // Uninterleave to get d : FA●RA → FB●RB
161        let lhs = interleave_blocks(&fa, &ra);
162        let rhs = interleave_blocks(&fb, &rb).dagger();
163        let d = lhs.compose(c).unwrap().compose(&rhs).unwrap();
164
165        // Verify source/target
166        // NOTE: unwrap() because coproduct of semifinite functions always succeeds.
167        debug_assert_eq!(d.source(), fa.coproduct(&ra).unwrap().values);
168        debug_assert_eq!(d.target(), fb.coproduct(&rb).unwrap().values);
169
170        // Partial dagger to get d : FA●RB → FB●RA
171        partial_dagger(&d, &fa, &fb, &rb, &ra)
172    }
173}
174
175fn interleave_blocks<K: ArrayKind, O, A>(
176    a: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
177    b: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
178) -> OpenHypergraph<K, O, A>
179where
180    K::Type<K::I>: NaturalArray<K>,
181    K::Type<O>: Array<K, O> + PartialEq,
182    K::Type<A>: Array<K, A>,
183{
184    if a.len() != b.len() {
185        panic!("Can't interleave types of unequal lengths");
186    }
187
188    let ab = a
189        .coproduct(b)
190        .expect("Coproduct of SemifiniteFunction always succeeds");
191
192    let two = K::I::one() + K::I::one();
193    let s = FiniteFunction::identity(ab.values.len());
194    let t = ab
195        .sources
196        .injections(&FiniteFunction::transpose(two, a.len()))
197        .unwrap();
198
199    OpenHypergraph::spider(s, t, ab.values.clone()).unwrap()
200}
201
202/// Helper function to perform partial dagger operation
203fn partial_dagger<K: ArrayKind + Debug, O, A>(
204    c: &OpenHypergraph<K, O, A>,
205    fa: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
206    fb: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
207    ra: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
208    rb: &IndexedCoproduct<K, SemifiniteFunction<K, O>>,
209) -> OpenHypergraph<K, O, A>
210where
211    K::Type<K::I>: NaturalArray<K>,
212    K::Type<O>: Array<K, O>,
213    K::Type<A>: Array<K, A>,
214{
215    let s = {
216        let s_i = FiniteFunction::inj0(fa.values.len(), rb.values.len())
217            .compose(&c.s)
218            .unwrap();
219
220        let s_o = FiniteFunction::inj1(fb.values.len(), ra.values.len())
221            .compose(&c.t)
222            .unwrap();
223
224        s_i.coproduct(&s_o).unwrap()
225    };
226
227    let t = {
228        let t_i = FiniteFunction::inj0(fb.values.len(), ra.values.len())
229            .compose(&c.t)
230            .unwrap();
231        let t_o = FiniteFunction::inj1(fa.values.len(), rb.values.len())
232            .compose(&c.s)
233            .unwrap();
234
235        t_i.coproduct(&t_o).unwrap()
236    };
237
238    OpenHypergraph::new(s, t, c.h.clone()).unwrap()
239}