1use super::*;
4
5struct FixReturnPendingVisitor<'tcx> {
7 tcx: TyCtxt<'tcx>,
8}
9
10impl<'tcx> MutVisitor<'tcx> for FixReturnPendingVisitor<'tcx> {
11 fn tcx(&self) -> TyCtxt<'tcx> {
12 self.tcx
13 }
14
15 fn visit_assign(
16 &mut self,
17 place: &mut Place<'tcx>,
18 rvalue: &mut Rvalue<'tcx>,
19 _location: Location,
20 ) {
21 if place.local != RETURN_PLACE {
22 return;
23 }
24
25 if let Rvalue::Aggregate(kind, _) = rvalue {
27 if let AggregateKind::Adt(_, _, ref mut args, _, _) = **kind {
28 *args = self.tcx.mk_args(&[self.tcx.types.unit.into()]);
29 }
30 }
31 }
32}
33
34fn build_poll_call<'tcx>(
36 tcx: TyCtxt<'tcx>,
37 body: &mut Body<'tcx>,
38 poll_unit_place: &Place<'tcx>,
39 switch_block: BasicBlock,
40 fut_pin_place: &Place<'tcx>,
41 fut_ty: Ty<'tcx>,
42 context_ref_place: &Place<'tcx>,
43 unwind: UnwindAction,
44) -> BasicBlock {
45 let poll_fn = tcx.require_lang_item(LangItem::FuturePoll, DUMMY_SP);
46 let poll_fn = Ty::new_fn_def(tcx, poll_fn, [fut_ty]);
47 let poll_fn = Operand::Constant(Box::new(ConstOperand {
48 span: DUMMY_SP,
49 user_ty: None,
50 const_: Const::zero_sized(poll_fn),
51 }));
52 let call = TerminatorKind::Call {
53 func: poll_fn.clone(),
54 args: [
55 dummy_spanned(Operand::Move(*fut_pin_place)),
56 dummy_spanned(Operand::Move(*context_ref_place)),
57 ]
58 .into(),
59 destination: *poll_unit_place,
60 target: Some(switch_block),
61 unwind,
62 call_source: CallSource::Misc,
63 fn_span: DUMMY_SP,
64 };
65 insert_term_block(body, call)
66}
67
68fn build_pin_fut<'tcx>(
70 tcx: TyCtxt<'tcx>,
71 body: &mut Body<'tcx>,
72 fut_place: Place<'tcx>,
73 unwind: UnwindAction,
74) -> (BasicBlock, Place<'tcx>) {
75 let span = body.span;
76 let source_info = SourceInfo::outermost(span);
77 let fut_ty = fut_place.ty(&body.local_decls, tcx).ty;
78 let fut_ref_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, fut_ty);
79 let fut_ref_place = Place::from(body.local_decls.push(LocalDecl::new(fut_ref_ty, span)));
80 let pin_fut_new_unchecked_fn =
81 Ty::new_fn_def(tcx, tcx.require_lang_item(LangItem::PinNewUnchecked, span), [fut_ref_ty]);
82 let fut_pin_ty = pin_fut_new_unchecked_fn.fn_sig(tcx).output().skip_binder();
83 let fut_pin_place = Place::from(body.local_decls.push(LocalDecl::new(fut_pin_ty, span)));
84 let pin_fut_new_unchecked_fn = Operand::Constant(Box::new(ConstOperand {
85 span,
86 user_ty: None,
87 const_: Const::zero_sized(pin_fut_new_unchecked_fn),
88 }));
89
90 let storage_live =
91 Statement { source_info, kind: StatementKind::StorageLive(fut_pin_place.local) };
92
93 let fut_ref_assign = Statement {
94 source_info,
95 kind: StatementKind::Assign(Box::new((
96 fut_ref_place,
97 Rvalue::Ref(
98 tcx.lifetimes.re_erased,
99 BorrowKind::Mut { kind: MutBorrowKind::Default },
100 fut_place,
101 ),
102 ))),
103 };
104
105 let pin_fut_bb = body.basic_blocks_mut().push(BasicBlockData {
107 statements: [storage_live, fut_ref_assign].to_vec(),
108 terminator: Some(Terminator {
109 source_info,
110 kind: TerminatorKind::Call {
111 func: pin_fut_new_unchecked_fn,
112 args: [dummy_spanned(Operand::Move(fut_ref_place))].into(),
113 destination: fut_pin_place,
114 target: None, unwind,
116 call_source: CallSource::Misc,
117 fn_span: span,
118 },
119 }),
120 is_cleanup: false,
121 });
122 (pin_fut_bb, fut_pin_place)
123}
124
125fn build_poll_switch<'tcx>(
131 tcx: TyCtxt<'tcx>,
132 body: &mut Body<'tcx>,
133 poll_enum: Ty<'tcx>,
134 poll_unit_place: &Place<'tcx>,
135 ready_block: BasicBlock,
136 yield_block: BasicBlock,
137) -> BasicBlock {
138 let poll_enum_adt = poll_enum.ty_adt_def().unwrap();
139
140 let Discr { val: poll_ready_discr, ty: poll_discr_ty } = poll_enum
141 .discriminant_for_variant(
142 tcx,
143 poll_enum_adt
144 .variant_index_with_id(tcx.require_lang_item(LangItem::PollReady, DUMMY_SP)),
145 )
146 .unwrap();
147 let poll_pending_discr = poll_enum
148 .discriminant_for_variant(
149 tcx,
150 poll_enum_adt
151 .variant_index_with_id(tcx.require_lang_item(LangItem::PollPending, DUMMY_SP)),
152 )
153 .unwrap()
154 .val;
155 let source_info = SourceInfo::outermost(body.span);
156 let poll_discr_place =
157 Place::from(body.local_decls.push(LocalDecl::new(poll_discr_ty, source_info.span)));
158 let discr_assign = Statement {
159 source_info,
160 kind: StatementKind::Assign(Box::new((
161 poll_discr_place,
162 Rvalue::Discriminant(*poll_unit_place),
163 ))),
164 };
165 let unreachable_block = insert_term_block(body, TerminatorKind::Unreachable);
166 body.basic_blocks_mut().push(BasicBlockData {
167 statements: [discr_assign].to_vec(),
168 terminator: Some(Terminator {
169 source_info,
170 kind: TerminatorKind::SwitchInt {
171 discr: Operand::Move(poll_discr_place),
172 targets: SwitchTargets::new(
173 [(poll_ready_discr, ready_block), (poll_pending_discr, yield_block)]
174 .into_iter(),
175 unreachable_block,
176 ),
177 },
178 }),
179 is_cleanup: false,
180 })
181}
182
183fn gather_dropline_blocks<'tcx>(body: &mut Body<'tcx>) -> DenseBitSet<BasicBlock> {
185 let mut dropline: DenseBitSet<BasicBlock> = DenseBitSet::new_empty(body.basic_blocks.len());
186 for (bb, data) in traversal::reverse_postorder(body) {
187 if dropline.contains(bb) {
188 data.terminator().successors().for_each(|v| {
189 dropline.insert(v);
190 });
191 } else {
192 match data.terminator().kind {
193 TerminatorKind::Yield { drop: Some(v), .. } => {
194 dropline.insert(v);
195 }
196 TerminatorKind::Drop { drop: Some(v), .. } => {
197 dropline.insert(v);
198 }
199 _ => (),
200 }
201 }
202 }
203 dropline
204}
205
206pub(super) fn cleanup_async_drops<'tcx>(body: &mut Body<'tcx>) {
208 for block in body.basic_blocks_mut() {
209 if let TerminatorKind::Drop {
210 place: _,
211 target: _,
212 unwind: _,
213 replace: _,
214 ref mut drop,
215 ref mut async_fut,
216 } = block.terminator_mut().kind
217 {
218 if drop.is_some() || async_fut.is_some() {
219 *drop = None;
220 *async_fut = None;
221 }
222 }
223 }
224}
225
226pub(super) fn has_expandable_async_drops<'tcx>(
227 tcx: TyCtxt<'tcx>,
228 body: &mut Body<'tcx>,
229 coroutine_ty: Ty<'tcx>,
230) -> bool {
231 for bb in START_BLOCK..body.basic_blocks.next_index() {
232 if body[bb].is_cleanup {
234 continue;
235 }
236 let TerminatorKind::Drop { place, target: _, unwind: _, replace: _, drop: _, async_fut } =
237 body[bb].terminator().kind
238 else {
239 continue;
240 };
241 let place_ty = place.ty(&body.local_decls, tcx).ty;
242 if place_ty == coroutine_ty {
243 continue;
244 }
245 if async_fut.is_none() {
246 continue;
247 }
248 return true;
249 }
250 return false;
251}
252
253pub(super) fn expand_async_drops<'tcx>(
255 tcx: TyCtxt<'tcx>,
256 body: &mut Body<'tcx>,
257 context_mut_ref: Ty<'tcx>,
258 coroutine_kind: hir::CoroutineKind,
259 coroutine_ty: Ty<'tcx>,
260) {
261 let dropline = gather_dropline_blocks(body);
262 let remove_asyncness = |block: &mut BasicBlockData<'tcx>| {
264 if let TerminatorKind::Drop {
265 place: _,
266 target: _,
267 unwind: _,
268 replace: _,
269 ref mut drop,
270 ref mut async_fut,
271 } = block.terminator_mut().kind
272 {
273 *drop = None;
274 *async_fut = None;
275 }
276 };
277 for bb in START_BLOCK..body.basic_blocks.next_index() {
278 if body[bb].is_cleanup {
280 remove_asyncness(&mut body[bb]);
281 continue;
282 }
283 let TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut } =
284 body[bb].terminator().kind
285 else {
286 continue;
287 };
288
289 let place_ty = place.ty(&body.local_decls, tcx).ty;
290 if place_ty == coroutine_ty {
291 remove_asyncness(&mut body[bb]);
292 continue;
293 }
294
295 let Some(fut_local) = async_fut else {
296 remove_asyncness(&mut body[bb]);
297 continue;
298 };
299
300 let is_dropline_bb = dropline.contains(bb);
301
302 if !is_dropline_bb && drop.is_none() {
303 remove_asyncness(&mut body[bb]);
304 continue;
305 }
306
307 let fut_place = Place::from(fut_local);
308 let fut_ty = fut_place.ty(&body.local_decls, tcx).ty;
309
310 let source_info = body[bb].terminator.as_ref().unwrap().source_info;
319
320 let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, source_info.span));
322 let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
323 let poll_decl = LocalDecl::new(poll_enum, source_info.span);
324 let poll_unit_place = Place::from(body.local_decls.push(poll_decl));
325
326 let context_ref_place =
328 Place::from(body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)));
329 let arg = Rvalue::Use(Operand::Move(Place::from(CTX_ARG)));
330 body[bb].statements.push(Statement {
331 source_info,
332 kind: StatementKind::Assign(Box::new((context_ref_place, arg))),
333 });
334 let yield_block = insert_term_block(body, TerminatorKind::Unreachable); let switch_block =
336 build_poll_switch(tcx, body, poll_enum, &poll_unit_place, target, yield_block);
337 let (pin_bb, fut_pin_place) =
338 build_pin_fut(tcx, body, fut_place.clone(), UnwindAction::Continue);
339 let call_bb = build_poll_call(
340 tcx,
341 body,
342 &poll_unit_place,
343 switch_block,
344 &fut_pin_place,
345 fut_ty,
346 &context_ref_place,
347 unwind,
348 );
349
350 let mut dropline_transition_bb: Option<BasicBlock> = None;
352 let mut dropline_yield_bb: Option<BasicBlock> = None;
353 let mut dropline_context_ref: Option<Place<'_>> = None;
354 let mut dropline_call_bb: Option<BasicBlock> = None;
355 if !is_dropline_bb {
356 let context_ref_place2: Place<'_> = Place::from(
357 body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)),
358 );
359 let drop_yield_block = insert_term_block(body, TerminatorKind::Unreachable); let drop_switch_block = build_poll_switch(
361 tcx,
362 body,
363 poll_enum,
364 &poll_unit_place,
365 drop.unwrap(),
366 drop_yield_block,
367 );
368 let (pin_bb2, fut_pin_place2) =
369 build_pin_fut(tcx, body, fut_place, UnwindAction::Continue);
370 let drop_call_bb = build_poll_call(
371 tcx,
372 body,
373 &poll_unit_place,
374 drop_switch_block,
375 &fut_pin_place2,
376 fut_ty,
377 &context_ref_place2,
378 unwind,
379 );
380 dropline_transition_bb = Some(pin_bb2);
381 dropline_yield_bb = Some(drop_yield_block);
382 dropline_context_ref = Some(context_ref_place2);
383 dropline_call_bb = Some(drop_call_bb);
384 }
385
386 let value =
387 if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _))
388 {
389 let full_yield_ty = body.yield_ty().unwrap();
391 let ty::Adt(_poll_adt, args) = *full_yield_ty.kind() else { bug!() };
392 let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
393 let yield_ty = args.type_at(0);
394 Operand::Constant(Box::new(ConstOperand {
395 span: source_info.span,
396 const_: Const::Unevaluated(
397 UnevaluatedConst::new(
398 tcx.require_lang_item(LangItem::AsyncGenPending, source_info.span),
399 tcx.mk_args(&[yield_ty.into()]),
400 ),
401 full_yield_ty,
402 ),
403 user_ty: None,
404 }))
405 } else {
406 Operand::Constant(Box::new(ConstOperand {
408 span: source_info.span,
409 user_ty: None,
410 const_: Const::from_bool(tcx, false),
411 }))
412 };
413
414 use rustc_middle::mir::AssertKind::ResumedAfterDrop;
415 let panic_bb = insert_panic_block(tcx, body, ResumedAfterDrop(coroutine_kind));
416
417 if is_dropline_bb {
418 body[yield_block].terminator_mut().kind = TerminatorKind::Yield {
419 value: value.clone(),
420 resume: panic_bb,
421 resume_arg: context_ref_place,
422 drop: Some(pin_bb),
423 };
424 } else {
425 body[yield_block].terminator_mut().kind = TerminatorKind::Yield {
426 value: value.clone(),
427 resume: pin_bb,
428 resume_arg: context_ref_place,
429 drop: dropline_transition_bb,
430 };
431 body[dropline_yield_bb.unwrap()].terminator_mut().kind = TerminatorKind::Yield {
432 value,
433 resume: panic_bb,
434 resume_arg: dropline_context_ref.unwrap(),
435 drop: dropline_transition_bb,
436 };
437 }
438
439 if let TerminatorKind::Call { ref mut target, .. } = body[pin_bb].terminator_mut().kind {
440 *target = Some(call_bb);
441 } else {
442 bug!()
443 }
444 if !is_dropline_bb {
445 if let TerminatorKind::Call { ref mut target, .. } =
446 body[dropline_transition_bb.unwrap()].terminator_mut().kind
447 {
448 *target = dropline_call_bb;
449 } else {
450 bug!()
451 }
452 }
453
454 body[bb].terminator_mut().kind = TerminatorKind::Goto { target: pin_bb };
455 }
456}
457
458pub(super) fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
459 use crate::elaborate_drop::{Unwind, elaborate_drop};
460 use crate::patch::MirPatch;
461 use crate::shim::DropShimElaborator;
462
463 let typing_env = body.typing_env(tcx);
467
468 let mut elaborator = DropShimElaborator {
469 body,
470 patch: MirPatch::new(body),
471 tcx,
472 typing_env,
473 produce_async_drops: false,
474 };
475
476 for (block, block_data) in body.basic_blocks.iter_enumerated() {
477 let (target, unwind, source_info, dropline) = match block_data.terminator() {
478 Terminator {
479 source_info,
480 kind: TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut: _ },
481 } => {
482 if let Some(local) = place.as_local()
483 && local == SELF_ARG
484 {
485 (target, unwind, source_info, *drop)
486 } else {
487 continue;
488 }
489 }
490 _ => continue,
491 };
492 let unwind = if block_data.is_cleanup {
493 Unwind::InCleanup
494 } else {
495 Unwind::To(match *unwind {
496 UnwindAction::Cleanup(tgt) => tgt,
497 UnwindAction::Continue => elaborator.patch.resume_block(),
498 UnwindAction::Unreachable => elaborator.patch.unreachable_cleanup_block(),
499 UnwindAction::Terminate(reason) => elaborator.patch.terminate_block(reason),
500 })
501 };
502 elaborate_drop(
503 &mut elaborator,
504 *source_info,
505 Place::from(SELF_ARG),
506 (),
507 *target,
508 unwind,
509 block,
510 dropline,
511 );
512 }
513 elaborator.patch.apply(body);
514}
515
516pub(super) fn insert_clean_drop<'tcx>(
517 tcx: TyCtxt<'tcx>,
518 body: &mut Body<'tcx>,
519 has_async_drops: bool,
520) -> BasicBlock {
521 let source_info = SourceInfo::outermost(body.span);
522 let return_block = if has_async_drops {
523 insert_poll_ready_block(tcx, body)
524 } else {
525 insert_term_block(body, TerminatorKind::Return)
526 };
527
528 let dropline = None;
532
533 let term = TerminatorKind::Drop {
534 place: Place::from(SELF_ARG),
535 target: return_block,
536 unwind: UnwindAction::Continue,
537 replace: false,
538 drop: dropline,
539 async_fut: None,
540 };
541
542 body.basic_blocks_mut().push(BasicBlockData {
544 statements: Vec::new(),
545 terminator: Some(Terminator { source_info, kind: term }),
546 is_cleanup: false,
547 })
548}
549
550pub(super) fn create_coroutine_drop_shim<'tcx>(
551 tcx: TyCtxt<'tcx>,
552 transform: &TransformVisitor<'tcx>,
553 coroutine_ty: Ty<'tcx>,
554 body: &Body<'tcx>,
555 drop_clean: BasicBlock,
556) -> Body<'tcx> {
557 let mut body = body.clone();
558 let _ = body.coroutine.take();
561 body.arg_count = 1;
564
565 let source_info = SourceInfo::outermost(body.span);
566
567 let mut cases = create_cases(&mut body, transform, Operation::Drop);
568
569 cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean));
570
571 let default_block = insert_term_block(&mut body, TerminatorKind::Return);
575 insert_switch(&mut body, cases, transform, default_block);
576
577 for block in body.basic_blocks_mut() {
578 let kind = &mut block.terminator_mut().kind;
579 if let TerminatorKind::CoroutineDrop = *kind {
580 *kind = TerminatorKind::Return;
581 }
582 }
583
584 body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(tcx.types.unit, source_info);
586
587 make_coroutine_state_argument_indirect(tcx, &mut body);
588
589 body.local_decls[SELF_ARG] =
591 LocalDecl::with_source_info(Ty::new_mut_ptr(tcx, coroutine_ty), source_info);
592
593 simplify::remove_dead_blocks(&mut body);
596
597 let coroutine_instance = body.source.instance;
599 let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, body.span);
600 let drop_instance = InstanceKind::DropGlue(drop_in_place, Some(coroutine_ty));
601
602 body.source.instance = coroutine_instance;
605 dump_mir(tcx, false, "coroutine_drop", &0, &body, |_, _| Ok(()));
606 body.source.instance = drop_instance;
607
608 body.phase = MirPhase::Runtime(RuntimePhase::Initial);
614
615 body
616}
617
618pub(super) fn create_coroutine_drop_shim_async<'tcx>(
620 tcx: TyCtxt<'tcx>,
621 transform: &TransformVisitor<'tcx>,
622 body: &Body<'tcx>,
623 drop_clean: BasicBlock,
624 can_unwind: bool,
625) -> Body<'tcx> {
626 let mut body = body.clone();
627 let _ = body.coroutine.take();
630
631 FixReturnPendingVisitor { tcx }.visit_body(&mut body);
632
633 if can_unwind {
635 generate_poison_block_and_redirect_unwinds_there(transform, &mut body);
636 }
637
638 let source_info = SourceInfo::outermost(body.span);
639
640 let mut cases = create_cases(&mut body, transform, Operation::Drop);
641
642 cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean));
643
644 use rustc_middle::mir::AssertKind::ResumedAfterPanic;
645 if can_unwind {
647 cases.insert(
648 1,
649 (
650 CoroutineArgs::POISONED,
651 insert_panic_block(tcx, &mut body, ResumedAfterPanic(transform.coroutine_kind)),
652 ),
653 );
654 }
655
656 let default_block = insert_poll_ready_block(tcx, &mut body);
659 insert_switch(&mut body, cases, transform, default_block);
660
661 for block in body.basic_blocks_mut() {
662 let kind = &mut block.terminator_mut().kind;
663 if let TerminatorKind::CoroutineDrop = *kind {
664 *kind = TerminatorKind::Return;
665 block.statements.push(return_poll_ready_assign(tcx, source_info));
666 }
667 }
668
669 let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, body.span));
671 let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
672 body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);
673
674 make_coroutine_state_argument_indirect(tcx, &mut body);
675
676 match transform.coroutine_kind {
677 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
680 _ => {
681 make_coroutine_state_argument_pinned(tcx, &mut body);
682 }
683 }
684
685 simplify::remove_dead_blocks(&mut body);
688
689 pm::run_passes_no_validate(
690 tcx,
691 &mut body,
692 &[&abort_unwinding_calls::AbortUnwindingCalls],
693 None,
694 );
695
696 dump_mir(tcx, false, "coroutine_drop_async", &0, &body, |_, _| Ok(()));
697
698 body
699}
700
701pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
704 tcx: TyCtxt<'tcx>,
705 body: &Body<'tcx>,
706) -> Body<'tcx> {
707 let mut body = body.clone();
708 let _ = body.coroutine.take();
711 let basic_blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>> = IndexVec::new();
712 body.basic_blocks = BasicBlocks::new(basic_blocks);
713 body.var_debug_info.clear();
714
715 body.local_decls.truncate(1 + body.arg_count);
717
718 let source_info = SourceInfo::outermost(body.span);
719
720 let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, body.span));
722 let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
723 body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);
724
725 let call_bb = body.basic_blocks_mut().push(BasicBlockData {
727 statements: Vec::new(),
728 terminator: None,
729 is_cleanup: false,
730 });
731
732 let ret_bb = insert_poll_ready_block(tcx, &mut body);
734
735 let kind = TerminatorKind::Drop {
736 place: Place::from(SELF_ARG),
737 target: ret_bb,
738 unwind: UnwindAction::Continue,
739 replace: false,
740 drop: None,
741 async_fut: None,
742 };
743 body.basic_blocks_mut()[call_bb].terminator = Some(Terminator { source_info, kind });
744
745 dump_mir(tcx, false, "coroutine_drop_proxy_async", &0, &body, |_, _| Ok(()));
746
747 body
748}