1use super::*;
4
5struct FixReturnPendingVisitor<'tcx> {
7 tcx: TyCtxt<'tcx>,
8}
9
10impl<'tcx> MutVisitor<'tcx> for FixReturnPendingVisitor<'tcx> {
11 fn tcx(&self) -> TyCtxt<'tcx> {
12 self.tcx
13 }
14
15 fn visit_assign(
16 &mut self,
17 place: &mut Place<'tcx>,
18 rvalue: &mut Rvalue<'tcx>,
19 _location: Location,
20 ) {
21 if place.local != RETURN_PLACE {
22 return;
23 }
24
25 if let Rvalue::Aggregate(kind, _) = rvalue {
27 if let AggregateKind::Adt(_, _, ref mut args, _, _) = **kind {
28 *args = self.tcx.mk_args(&[self.tcx.types.unit.into()]);
29 }
30 }
31 }
32}
33
34fn build_poll_call<'tcx>(
36 tcx: TyCtxt<'tcx>,
37 body: &mut Body<'tcx>,
38 poll_unit_place: &Place<'tcx>,
39 switch_block: BasicBlock,
40 fut_pin_place: &Place<'tcx>,
41 fut_ty: Ty<'tcx>,
42 context_ref_place: &Place<'tcx>,
43 unwind: UnwindAction,
44) -> BasicBlock {
45 let poll_fn = tcx.require_lang_item(LangItem::FuturePoll, DUMMY_SP);
46 let poll_fn = Ty::new_fn_def(tcx, poll_fn, [fut_ty]);
47 let poll_fn = Operand::Constant(Box::new(ConstOperand {
48 span: DUMMY_SP,
49 user_ty: None,
50 const_: Const::zero_sized(poll_fn),
51 }));
52 let call = TerminatorKind::Call {
53 func: poll_fn.clone(),
54 args: [
55 dummy_spanned(Operand::Move(*fut_pin_place)),
56 dummy_spanned(Operand::Move(*context_ref_place)),
57 ]
58 .into(),
59 destination: *poll_unit_place,
60 target: Some(switch_block),
61 unwind,
62 call_source: CallSource::Misc,
63 fn_span: DUMMY_SP,
64 };
65 insert_term_block(body, call)
66}
67
/// Appends a block that pins the drop future in place, and returns it with the pinned place.
///
/// Two fresh locals are pushed, in this order: a `&mut FutTy` reference to `fut_place`, and
/// the pinned value, whose type is the output of the `PinNewUnchecked` lang item's signature.
///
/// The new block contains `StorageLive(fut_pin)` and `fut_ref = &mut fut_place`. It ends in
/// the call `fut_pin = PinNewUnchecked::<&mut FutTy>(move fut_ref)` with the given `unwind`.
///
/// The call's `target` is `None`; the caller must patch it to the block that polls
/// `fut_pin`. Returns `(pin_block, fut_pin_place)`.
fn build_pin_fut<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &mut Body<'tcx>,
    fut_place: Place<'tcx>,
    unwind: UnwindAction,
) -> (BasicBlock, Place<'tcx>) {
    let span = body.span;
    let source_info = SourceInfo::outermost(span);
    let fut_ty = fut_place.ty(&body.local_decls, tcx).ty;
    let fut_ref_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, fut_ty);
    let fut_ref_place = Place::from(body.local_decls.push(LocalDecl::new(fut_ref_ty, span)));
    let pin_fut_new_unchecked_fn =
        Ty::new_fn_def(tcx, tcx.require_lang_item(LangItem::PinNewUnchecked, span), [fut_ref_ty]);
    // The pinned local's type is read off the constructor's return type.
    let fut_pin_ty = pin_fut_new_unchecked_fn.fn_sig(tcx).output().skip_binder();
    let fut_pin_place = Place::from(body.local_decls.push(LocalDecl::new(fut_pin_ty, span)));
    let pin_fut_new_unchecked_fn = Operand::Constant(Box::new(ConstOperand {
        span,
        user_ty: None,
        const_: Const::zero_sized(pin_fut_new_unchecked_fn),
    }));

    let storage_live =
        Statement { source_info, kind: StatementKind::StorageLive(fut_pin_place.local) };

    // fut_ref = &mut fut
    let fut_ref_assign = Statement {
        source_info,
        kind: StatementKind::Assign(Box::new((
            fut_ref_place,
            Rvalue::Ref(
                tcx.lifetimes.re_erased,
                BorrowKind::Mut { kind: MutBorrowKind::Default },
                fut_place,
            ),
        ))),
    };

    // fut_pin = Pin::new_unchecked(move fut_ref)
    let pin_fut_bb = body.basic_blocks_mut().push(BasicBlockData {
        statements: [storage_live, fut_ref_assign].to_vec(),
        terminator: Some(Terminator {
            source_info,
            kind: TerminatorKind::Call {
                func: pin_fut_new_unchecked_fn,
                args: [dummy_spanned(Operand::Move(fut_ref_place))].into(),
                destination: fut_pin_place,
                // Patched by the caller once the poll block exists.
                target: None,
                unwind,
                call_source: CallSource::Misc,
                fn_span: span,
            },
        }),
        is_cleanup: false,
    });
    (pin_fut_bb, fut_pin_place)
}
124
/// Appends a block that dispatches on the result of a poll call, and returns that block.
///
/// The block runs `StorageDead(fut_pin)` and then `discr = discriminant(*poll_unit_place)`.
/// It ends in `switchInt(move discr)` with these arms:
/// - the `PollReady` variant's discriminant goes to `ready_block`;
/// - the `PollPending` variant's discriminant goes to `yield_block`;
/// - any other value goes to a fresh `Unreachable` block, created before the switch block.
///
/// `poll_enum` must be an ADT type, because `ty_adt_def()` is unwrapped.
fn build_poll_switch<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &mut Body<'tcx>,
    poll_enum: Ty<'tcx>,
    poll_unit_place: &Place<'tcx>,
    fut_pin_place: &Place<'tcx>,
    ready_block: BasicBlock,
    yield_block: BasicBlock,
) -> BasicBlock {
    let poll_enum_adt = poll_enum.ty_adt_def().unwrap();

    // Discriminant value (and its integer type) of the `Ready` variant.
    let Discr { val: poll_ready_discr, ty: poll_discr_ty } = poll_enum
        .discriminant_for_variant(
            tcx,
            poll_enum_adt
                .variant_index_with_id(tcx.require_lang_item(LangItem::PollReady, DUMMY_SP)),
        )
        .unwrap();
    // Discriminant value of the `Pending` variant.
    let poll_pending_discr = poll_enum
        .discriminant_for_variant(
            tcx,
            poll_enum_adt
                .variant_index_with_id(tcx.require_lang_item(LangItem::PollPending, DUMMY_SP)),
        )
        .unwrap()
        .val;
    let source_info = SourceInfo::outermost(body.span);
    let poll_discr_place =
        Place::from(body.local_decls.push(LocalDecl::new(poll_discr_ty, source_info.span)));
    let discr_assign = Statement {
        source_info,
        kind: StatementKind::Assign(Box::new((
            poll_discr_place,
            Rvalue::Discriminant(*poll_unit_place),
        ))),
    };
    // The pinned future was moved into the poll call; end its storage here.
    let storage_dead =
        Statement { source_info, kind: StatementKind::StorageDead(fut_pin_place.local) };
    let unreachable_block = insert_term_block(body, TerminatorKind::Unreachable);
    body.basic_blocks_mut().push(BasicBlockData {
        statements: [storage_dead, discr_assign].to_vec(),
        terminator: Some(Terminator {
            source_info,
            kind: TerminatorKind::SwitchInt {
                discr: Operand::Move(poll_discr_place),
                targets: SwitchTargets::new(
                    [(poll_ready_discr, ready_block), (poll_pending_discr, yield_block)]
                        .into_iter(),
                    unreachable_block,
                ),
            },
        }),
        is_cleanup: false,
    })
}
185
186fn gather_dropline_blocks<'tcx>(body: &mut Body<'tcx>) -> DenseBitSet<BasicBlock> {
188 let mut dropline: DenseBitSet<BasicBlock> = DenseBitSet::new_empty(body.basic_blocks.len());
189 for (bb, data) in traversal::reverse_postorder(body) {
190 if dropline.contains(bb) {
191 data.terminator().successors().for_each(|v| {
192 dropline.insert(v);
193 });
194 } else {
195 match data.terminator().kind {
196 TerminatorKind::Yield { drop: Some(v), .. } => {
197 dropline.insert(v);
198 }
199 TerminatorKind::Drop { drop: Some(v), .. } => {
200 dropline.insert(v);
201 }
202 _ => (),
203 }
204 }
205 }
206 dropline
207}
208
209pub(super) fn cleanup_async_drops<'tcx>(body: &mut Body<'tcx>) {
211 for block in body.basic_blocks_mut() {
212 if let TerminatorKind::Drop {
213 place: _,
214 target: _,
215 unwind: _,
216 replace: _,
217 ref mut drop,
218 ref mut async_fut,
219 } = block.terminator_mut().kind
220 {
221 if drop.is_some() || async_fut.is_some() {
222 *drop = None;
223 *async_fut = None;
224 }
225 }
226 }
227}
228
229pub(super) fn has_expandable_async_drops<'tcx>(
230 tcx: TyCtxt<'tcx>,
231 body: &mut Body<'tcx>,
232 coroutine_ty: Ty<'tcx>,
233) -> bool {
234 for bb in START_BLOCK..body.basic_blocks.next_index() {
235 if body[bb].is_cleanup {
237 continue;
238 }
239 let TerminatorKind::Drop { place, target: _, unwind: _, replace: _, drop: _, async_fut } =
240 body[bb].terminator().kind
241 else {
242 continue;
243 };
244 let place_ty = place.ty(&body.local_decls, tcx).ty;
245 if place_ty == coroutine_ty {
246 continue;
247 }
248 if async_fut.is_none() {
249 continue;
250 }
251 return true;
252 }
253 return false;
254}
255
256pub(super) fn expand_async_drops<'tcx>(
258 tcx: TyCtxt<'tcx>,
259 body: &mut Body<'tcx>,
260 context_mut_ref: Ty<'tcx>,
261 coroutine_kind: hir::CoroutineKind,
262 coroutine_ty: Ty<'tcx>,
263) {
264 let dropline = gather_dropline_blocks(body);
265 let remove_asyncness = |block: &mut BasicBlockData<'tcx>| {
267 if let TerminatorKind::Drop {
268 place: _,
269 target: _,
270 unwind: _,
271 replace: _,
272 ref mut drop,
273 ref mut async_fut,
274 } = block.terminator_mut().kind
275 {
276 *drop = None;
277 *async_fut = None;
278 }
279 };
280 for bb in START_BLOCK..body.basic_blocks.next_index() {
281 if body[bb].is_cleanup {
283 remove_asyncness(&mut body[bb]);
284 continue;
285 }
286 let TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut } =
287 body[bb].terminator().kind
288 else {
289 continue;
290 };
291
292 let place_ty = place.ty(&body.local_decls, tcx).ty;
293 if place_ty == coroutine_ty {
294 remove_asyncness(&mut body[bb]);
295 continue;
296 }
297
298 let Some(fut_local) = async_fut else {
299 remove_asyncness(&mut body[bb]);
300 continue;
301 };
302
303 let is_dropline_bb = dropline.contains(bb);
304
305 if !is_dropline_bb && drop.is_none() {
306 remove_asyncness(&mut body[bb]);
307 continue;
308 }
309
310 let fut_place = Place::from(fut_local);
311 let fut_ty = fut_place.ty(&body.local_decls, tcx).ty;
312
313 let source_info = body[bb].terminator.as_ref().unwrap().source_info;
322
323 let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, source_info.span));
325 let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
326 let poll_decl = LocalDecl::new(poll_enum, source_info.span);
327 let poll_unit_place = Place::from(body.local_decls.push(poll_decl));
328
329 let context_ref_place =
331 Place::from(body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)));
332 let arg = Rvalue::Use(Operand::Move(Place::from(CTX_ARG)));
333 body[bb].statements.push(Statement {
334 source_info,
335 kind: StatementKind::Assign(Box::new((context_ref_place, arg))),
336 });
337 let yield_block = insert_term_block(body, TerminatorKind::Unreachable); let (pin_bb, fut_pin_place) =
339 build_pin_fut(tcx, body, fut_place.clone(), UnwindAction::Continue);
340 let switch_block = build_poll_switch(
341 tcx,
342 body,
343 poll_enum,
344 &poll_unit_place,
345 &fut_pin_place,
346 target,
347 yield_block,
348 );
349 let call_bb = build_poll_call(
350 tcx,
351 body,
352 &poll_unit_place,
353 switch_block,
354 &fut_pin_place,
355 fut_ty,
356 &context_ref_place,
357 unwind,
358 );
359
360 let mut dropline_transition_bb: Option<BasicBlock> = None;
362 let mut dropline_yield_bb: Option<BasicBlock> = None;
363 let mut dropline_context_ref: Option<Place<'_>> = None;
364 let mut dropline_call_bb: Option<BasicBlock> = None;
365 if !is_dropline_bb {
366 let context_ref_place2: Place<'_> = Place::from(
367 body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)),
368 );
369 let drop_yield_block = insert_term_block(body, TerminatorKind::Unreachable); let (pin_bb2, fut_pin_place2) =
371 build_pin_fut(tcx, body, fut_place, UnwindAction::Continue);
372 let drop_switch_block = build_poll_switch(
373 tcx,
374 body,
375 poll_enum,
376 &poll_unit_place,
377 &fut_pin_place2,
378 drop.unwrap(),
379 drop_yield_block,
380 );
381 let drop_call_bb = build_poll_call(
382 tcx,
383 body,
384 &poll_unit_place,
385 drop_switch_block,
386 &fut_pin_place2,
387 fut_ty,
388 &context_ref_place2,
389 unwind,
390 );
391 dropline_transition_bb = Some(pin_bb2);
392 dropline_yield_bb = Some(drop_yield_block);
393 dropline_context_ref = Some(context_ref_place2);
394 dropline_call_bb = Some(drop_call_bb);
395 }
396
397 let value =
398 if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _))
399 {
400 let full_yield_ty = body.yield_ty().unwrap();
402 let ty::Adt(_poll_adt, args) = *full_yield_ty.kind() else { bug!() };
403 let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
404 let yield_ty = args.type_at(0);
405 Operand::Constant(Box::new(ConstOperand {
406 span: source_info.span,
407 const_: Const::Unevaluated(
408 UnevaluatedConst::new(
409 tcx.require_lang_item(LangItem::AsyncGenPending, source_info.span),
410 tcx.mk_args(&[yield_ty.into()]),
411 ),
412 full_yield_ty,
413 ),
414 user_ty: None,
415 }))
416 } else {
417 Operand::Constant(Box::new(ConstOperand {
419 span: source_info.span,
420 user_ty: None,
421 const_: Const::from_bool(tcx, false),
422 }))
423 };
424
425 use rustc_middle::mir::AssertKind::ResumedAfterDrop;
426 let panic_bb = insert_panic_block(tcx, body, ResumedAfterDrop(coroutine_kind));
427
428 if is_dropline_bb {
429 body[yield_block].terminator_mut().kind = TerminatorKind::Yield {
430 value: value.clone(),
431 resume: panic_bb,
432 resume_arg: context_ref_place,
433 drop: Some(pin_bb),
434 };
435 } else {
436 body[yield_block].terminator_mut().kind = TerminatorKind::Yield {
437 value: value.clone(),
438 resume: pin_bb,
439 resume_arg: context_ref_place,
440 drop: dropline_transition_bb,
441 };
442 body[dropline_yield_bb.unwrap()].terminator_mut().kind = TerminatorKind::Yield {
443 value,
444 resume: panic_bb,
445 resume_arg: dropline_context_ref.unwrap(),
446 drop: dropline_transition_bb,
447 };
448 }
449
450 if let TerminatorKind::Call { ref mut target, .. } = body[pin_bb].terminator_mut().kind {
451 *target = Some(call_bb);
452 } else {
453 bug!()
454 }
455 if !is_dropline_bb {
456 if let TerminatorKind::Call { ref mut target, .. } =
457 body[dropline_transition_bb.unwrap()].terminator_mut().kind
458 {
459 *target = dropline_call_bb;
460 } else {
461 bug!()
462 }
463 }
464
465 body[bb].terminator_mut().kind = TerminatorKind::Goto { target: pin_bb };
466 }
467}
468
/// Elaborates every `Drop` whose place is exactly the `SELF_ARG` local.
///
/// Each such drop is expanded with `elaborate_drop`, using a `DropShimElaborator` configured
/// with `produce_async_drops: false`. `Drop` terminators on any other place are skipped.
///
/// Each elaborated drop keeps its `target` and its `drop` (dropline) target. Its unwind action
/// is translated into an `Unwind`:
/// - drops in cleanup blocks use `Unwind::InCleanup`;
/// - `Cleanup(bb)` unwinds to `bb`;
/// - `Continue`, `Unreachable` and `Terminate` unwind to the matching patch block.
///
/// All changes are recorded in a `MirPatch` that is applied to `body` at the end.
pub(super) fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
    use crate::elaborate_drop::{Unwind, elaborate_drop};
    use crate::patch::MirPatch;
    use crate::shim::DropShimElaborator;

    let typing_env = body.typing_env(tcx);

    let mut elaborator = DropShimElaborator {
        body,
        patch: MirPatch::new(body),
        tcx,
        typing_env,
        produce_async_drops: false,
    };

    for (block, block_data) in body.basic_blocks.iter_enumerated() {
        // Select only drops of the whole `SELF_ARG` local; everything else is skipped.
        let (target, unwind, source_info, dropline) = match block_data.terminator() {
            Terminator {
                source_info,
                kind: TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut: _ },
            } => {
                if let Some(local) = place.as_local()
                    && local == SELF_ARG
                {
                    (target, unwind, source_info, *drop)
                } else {
                    continue;
                }
            }
            _ => continue,
        };
        // Map the terminator's `UnwindAction` onto the elaborator's `Unwind`.
        let unwind = if block_data.is_cleanup {
            Unwind::InCleanup
        } else {
            Unwind::To(match *unwind {
                UnwindAction::Cleanup(tgt) => tgt,
                UnwindAction::Continue => elaborator.patch.resume_block(),
                UnwindAction::Unreachable => elaborator.patch.unreachable_cleanup_block(),
                UnwindAction::Terminate(reason) => elaborator.patch.terminate_block(reason),
            })
        };
        elaborate_drop(
            &mut elaborator,
            *source_info,
            Place::from(SELF_ARG),
            (),
            *target,
            unwind,
            block,
            dropline,
        );
    }
    elaborator.patch.apply(body);
}
526
527pub(super) fn insert_clean_drop<'tcx>(
528 tcx: TyCtxt<'tcx>,
529 body: &mut Body<'tcx>,
530 has_async_drops: bool,
531) -> BasicBlock {
532 let source_info = SourceInfo::outermost(body.span);
533 let return_block = if has_async_drops {
534 insert_poll_ready_block(tcx, body)
535 } else {
536 insert_term_block(body, TerminatorKind::Return)
537 };
538
539 let dropline = None;
543
544 let term = TerminatorKind::Drop {
545 place: Place::from(SELF_ARG),
546 target: return_block,
547 unwind: UnwindAction::Continue,
548 replace: false,
549 drop: dropline,
550 async_fut: None,
551 };
552
553 body.basic_blocks_mut().push(BasicBlockData {
555 statements: Vec::new(),
556 terminator: Some(Terminator { source_info, kind: term }),
557 is_cleanup: false,
558 })
559}
560
/// Builds the synchronous drop-glue body for a coroutine from a clone of `body`.
///
/// The shim switches on the coroutine state:
/// - `UNRESUMED` goes to `drop_clean`;
/// - the states from `create_cases(.., Operation::Drop)` go to their drop paths;
/// - every other state falls through to a block that just returns.
///
/// The body is then adjusted:
/// - `CoroutineDrop` terminators become `Return`;
/// - `_0` is re-typed to `()`;
/// - the state argument is made indirect, and `SELF_ARG` is re-typed to `*mut coroutine_ty`;
/// - dead blocks are removed.
///
/// The returned body's source instance is `InstanceKind::DropGlue(drop_in_place,
/// Some(coroutine_ty))` and its phase is `Runtime(Initial)`.
pub(super) fn create_coroutine_drop_shim<'tcx>(
    tcx: TyCtxt<'tcx>,
    transform: &TransformVisitor<'tcx>,
    coroutine_ty: Ty<'tcx>,
    body: &Body<'tcx>,
    drop_clean: BasicBlock,
) -> Body<'tcx> {
    let mut body = body.clone();
    // The shim is not itself a coroutine body; discard its coroutine info.
    let _ = body.coroutine.take();
    // Only one argument is kept in the shim's signature.
    body.arg_count = 1;

    let source_info = SourceInfo::outermost(body.span);

    let mut cases = create_cases(&mut body, transform, Operation::Drop);

    cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean));

    // States without a case (not listed above) fall through to a plain return.
    let default_block = insert_term_block(&mut body, TerminatorKind::Return);
    insert_switch(&mut body, cases, transform, default_block);

    for block in body.basic_blocks_mut() {
        let kind = &mut block.terminator_mut().kind;
        if let TerminatorKind::CoroutineDrop = *kind {
            *kind = TerminatorKind::Return;
        }
    }

    // Drop glue returns `()`.
    body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(tcx.types.unit, source_info);

    make_coroutine_state_argument_indirect(tcx, &mut body);

    // The coroutine argument becomes a raw `*mut` pointer.
    body.local_decls[SELF_ARG] =
        LocalDecl::with_source_info(Ty::new_mut_ptr(tcx, coroutine_ty), source_info);

    // Drop blocks that the new switch no longer reaches.
    simplify::remove_dead_blocks(&mut body);

    let coroutine_instance = body.source.instance;
    let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, body.span);
    let drop_instance = InstanceKind::DropGlue(drop_in_place, Some(coroutine_ty));

    // NOTE(review): this reassigns the value read above, so it is a no-op as written;
    // `dump_mir` sees the coroutine's own instance either way.
    body.source.instance = coroutine_instance;
    dump_mir(tcx, false, "coroutine_drop", &0, &body, |_, _| Ok(()));
    body.source.instance = drop_instance;

    // Set the phase explicitly. NOTE(review): presumably because the pass manager does not
    // update it for this shim — confirm.
    body.phase = MirPhase::Runtime(RuntimePhase::Initial);

    body
}
628
/// Builds the async drop-shim body for a coroutine from a clone of `body`.
///
/// The shim keeps a poll-style shape:
/// - `_0` is re-typed to `Poll<()>`, and ADT aggregates written to `_0` get `()` as their
///   generic args (`FixReturnPendingVisitor`);
/// - `CoroutineDrop` terminators become `Return`, preceded by `return_poll_ready_assign`;
/// - states without a case fall through to a poll-ready block.
///
/// When `can_unwind` is set, unwinds are redirected to a poison block. The `POISONED` state
/// then panics with `ResumedAfterPanic`.
///
/// The state argument is made indirect and, for every kind except `gen` coroutines, pinned.
/// Dead blocks are removed, and `AbortUnwindingCalls` is run (without validation) before
/// returning.
pub(super) fn create_coroutine_drop_shim_async<'tcx>(
    tcx: TyCtxt<'tcx>,
    transform: &TransformVisitor<'tcx>,
    body: &Body<'tcx>,
    drop_clean: BasicBlock,
    can_unwind: bool,
) -> Body<'tcx> {
    let mut body = body.clone();
    // The shim is not itself a coroutine body; discard its coroutine info.
    let _ = body.coroutine.take();

    // Retype ADT aggregates assigned to `_0` to use `()` args, matching `_0`'s new type below.
    FixReturnPendingVisitor { tcx }.visit_body(&mut body);

    // Redirect unwinds to a block that poisons the coroutine.
    if can_unwind {
        generate_poison_block_and_redirect_unwinds_there(transform, &mut body);
    }

    let source_info = SourceInfo::outermost(body.span);

    let mut cases = create_cases(&mut body, transform, Operation::Drop);

    cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean));

    use rustc_middle::mir::AssertKind::ResumedAfterPanic;
    // A poisoned coroutine panics when it is async-dropped.
    if can_unwind {
        cases.insert(
            1,
            (
                CoroutineArgs::POISONED,
                insert_panic_block(tcx, &mut body, ResumedAfterPanic(transform.coroutine_kind)),
            ),
        );
    }

    // States without a case (not listed above) just report readiness.
    let default_block = insert_poll_ready_block(tcx, &mut body);
    insert_switch(&mut body, cases, transform, default_block);

    for block in body.basic_blocks_mut() {
        let kind = &mut block.terminator_mut().kind;
        if let TerminatorKind::CoroutineDrop = *kind {
            *kind = TerminatorKind::Return;
            block.statements.push(return_poll_ready_assign(tcx, source_info));
        }
    }

    // The return place becomes `Poll<()>`.
    let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, body.span));
    let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
    body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);

    make_coroutine_state_argument_indirect(tcx, &mut body);

    match transform.coroutine_kind {
        // `gen` coroutines keep an unpinned state argument.
        CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
        _ => {
            make_coroutine_state_argument_pinned(tcx, &mut body);
        }
    }

    // Drop blocks that the new switch no longer reaches.
    simplify::remove_dead_blocks(&mut body);

    pm::run_passes_no_validate(
        tcx,
        &mut body,
        &[&abort_unwinding_calls::AbortUnwindingCalls],
        None,
    );

    dump_mir(tcx, false, "coroutine_drop_async", &0, &body, |_, _| Ok(()));

    body
}
711
/// Builds a minimal async-drop proxy body: `drop(SELF_ARG)` followed by a poll-ready return.
///
/// Starts from a clone of `body` and removes its basic blocks, its debug info, and every local
/// except the return place and the arguments. `_0` is re-typed to `Poll<()>`.
pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
) -> Body<'tcx> {
    let mut body = body.clone();
    // The proxy is not itself a coroutine body; discard its coroutine info.
    let _ = body.coroutine.take();
    let basic_blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>> = IndexVec::new();
    body.basic_blocks = BasicBlocks::new(basic_blocks);
    body.var_debug_info.clear();

    // Keep only `_0` and the argument locals.
    body.local_decls.truncate(1 + body.arg_count);

    let source_info = SourceInfo::outermost(body.span);

    // The return place becomes `Poll<()>`.
    let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, body.span));
    let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
    body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);

    // Entry block; its terminator is filled in once `ret_bb` exists.
    let call_bb = body.basic_blocks_mut().push(BasicBlockData {
        statements: Vec::new(),
        terminator: None,
        is_cleanup: false,
    });

    let ret_bb = insert_poll_ready_block(tcx, &mut body);

    // drop(SELF_ARG) -> ret_bb
    let kind = TerminatorKind::Drop {
        place: Place::from(SELF_ARG),
        target: ret_bb,
        unwind: UnwindAction::Continue,
        replace: false,
        drop: None,
        async_fut: None,
    };
    body.basic_blocks_mut()[call_bb].terminator = Some(Terminator { source_info, kind });

    dump_mir(tcx, false, "coroutine_drop_proxy_async", &0, &body, |_, _| Ok(()));

    body
}