1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
//! A MIR pass which duplicates a coroutine's body and removes any derefs which
//! would be present for upvars that are taken by-ref. The result of which will
//! be a coroutine body that takes all of its upvars by-move, and which we stash
//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures.

use rustc_data_structures::fx::FxIndexSet;
use rustc_hir as hir;
use rustc_middle::mir::visit::MutVisitor;
use rustc_middle::mir::{self, dump_mir, MirPass};
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
use rustc_target::abi::FieldIdx;

pub struct ByMoveBody;

impl<'tcx> MirPass<'tcx> for ByMoveBody {
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) {
        let Some(coroutine_def_id) = body.source.def_id().as_local() else {
            return;
        };
        let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) =
            tcx.coroutine_kind(coroutine_def_id)
        else {
            return;
        };
        let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
        if coroutine_ty.references_error() {
            return;
        }
        let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };

        let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
        if coroutine_kind == ty::ClosureKind::FnOnce {
            return;
        }

        let mut by_ref_fields = FxIndexSet::default();
        let by_move_upvars = Ty::new_tup_from_iter(
            tcx,
            tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
                if capture.is_by_ref() {
                    by_ref_fields.insert(FieldIdx::from_usize(idx));
                }
                capture.place.ty()
            }),
        );
        let by_move_coroutine_ty = Ty::new_coroutine(
            tcx,
            coroutine_def_id.to_def_id(),
            ty::CoroutineArgs::new(
                tcx,
                ty::CoroutineArgsParts {
                    parent_args: args.as_coroutine().parent_args(),
                    kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
                    resume_ty: args.as_coroutine().resume_ty(),
                    yield_ty: args.as_coroutine().yield_ty(),
                    return_ty: args.as_coroutine().return_ty(),
                    witness: args.as_coroutine().witness(),
                    tupled_upvars_ty: by_move_upvars,
                },
            )
            .args,
        );

        let mut by_move_body = body.clone();
        MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
        dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
        by_move_body.source = mir::MirSource {
            instance: InstanceDef::CoroutineKindShim {
                coroutine_def_id: coroutine_def_id.to_def_id(),
                target_kind: ty::ClosureKind::FnOnce,
            },
            promoted: None,
        };
        body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);

        // If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body.
        // This is actually just a copy of the by-ref body, but with a different self type.
        // FIXME(async_closures): We could probably unify this with the by-ref body somehow.
        if coroutine_kind == ty::ClosureKind::Fn {
            let by_mut_coroutine_ty = Ty::new_coroutine(
                tcx,
                coroutine_def_id.to_def_id(),
                ty::CoroutineArgs::new(
                    tcx,
                    ty::CoroutineArgsParts {
                        parent_args: args.as_coroutine().parent_args(),
                        kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut),
                        resume_ty: args.as_coroutine().resume_ty(),
                        yield_ty: args.as_coroutine().yield_ty(),
                        return_ty: args.as_coroutine().return_ty(),
                        witness: args.as_coroutine().witness(),
                        tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(),
                    },
                )
                .args,
            );
            let mut by_mut_body = body.clone();
            by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty;
            dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(()));
            by_mut_body.source = mir::MirSource {
                instance: InstanceDef::CoroutineKindShim {
                    coroutine_def_id: coroutine_def_id.to_def_id(),
                    target_kind: ty::ClosureKind::FnMut,
                },
                promoted: None,
            };
            body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body);
        }
    }
}

struct MakeByMoveBody<'tcx> {
    tcx: TyCtxt<'tcx>,
    by_ref_fields: FxIndexSet<FieldIdx>,
    by_move_coroutine_ty: Ty<'tcx>,
}

impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    fn visit_place(
        &mut self,
        place: &mut mir::Place<'tcx>,
        context: mir::visit::PlaceContext,
        location: mir::Location,
    ) {
        if place.local == ty::CAPTURE_STRUCT_LOCAL
            && !place.projection.is_empty()
            && let mir::ProjectionElem::Field(idx, ty) = place.projection[0]
            && self.by_ref_fields.contains(&idx)
        {
            let (begin, end) = place.projection[1..].split_first().unwrap();
            // FIXME(async_closures): I'm actually a bit surprised to see that we always
            // initially deref the by-ref upvars. If this is not actually true, then we
            // will at least get an ICE that explains why this isn't true :^)
            assert_eq!(*begin, mir::ProjectionElem::Deref);
            // Peel one ref off of the ty.
            let peeled_ty = ty.builtin_deref(true).unwrap().ty;
            *place = mir::Place {
                local: place.local,
                projection: self.tcx.mk_place_elems_from_iter(
                    [mir::ProjectionElem::Field(idx, peeled_ty)]
                        .into_iter()
                        .chain(end.iter().copied()),
                ),
            };
        }
        self.super_place(place, context, location);
    }

    fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
        // Replace the type of the self arg.
        if local == ty::CAPTURE_STRUCT_LOCAL {
            local_decl.ty = self.by_move_coroutine_ty;
        }
    }
}