miri/shims/x86/
avx.rs

1use rustc_abi::CanonAbi;
2use rustc_apfloat::ieee::{Double, Single};
3use rustc_middle::mir;
4use rustc_middle::ty::Ty;
5use rustc_middle::ty::layout::LayoutOf as _;
6use rustc_span::Symbol;
7use rustc_target::callconv::FnAbi;
8
9use super::{
10    FloatBinOp, FloatUnaryOp, bin_op_simd_float_all, conditional_dot_product, convert_float_to_int,
11    horizontal_bin_op, mask_load, mask_store, round_all, test_bits_masked, test_high_bits_masked,
12    unary_op_ps,
13};
14use crate::*;
15
16impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
17pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
18    fn emulate_x86_avx_intrinsic(
19        &mut self,
20        link_name: Symbol,
21        abi: &FnAbi<'tcx, Ty<'tcx>>,
22        args: &[OpTy<'tcx>],
23        dest: &MPlaceTy<'tcx>,
24    ) -> InterpResult<'tcx, EmulateItemResult> {
25        let this = self.eval_context_mut();
26        this.expect_target_feature_for_intrinsic(link_name, "avx")?;
27        // Prefix should have already been checked.
28        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx.").unwrap();
29
30        match unprefixed_name {
31            // Used to implement _mm256_min_ps and _mm256_max_ps functions.
32            // Note that the semantics are a bit different from Rust simd_min
33            // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
34            // matches the IEEE min/max operations, while x86 has different
35            // semantics.
36            "min.ps.256" | "max.ps.256" => {
37                let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
38
39                let which = match unprefixed_name {
40                    "min.ps.256" => FloatBinOp::Min,
41                    "max.ps.256" => FloatBinOp::Max,
42                    _ => unreachable!(),
43                };
44
45                bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
46            }
47            // Used to implement _mm256_min_pd and _mm256_max_pd functions.
48            "min.pd.256" | "max.pd.256" => {
49                let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
50
51                let which = match unprefixed_name {
52                    "min.pd.256" => FloatBinOp::Min,
53                    "max.pd.256" => FloatBinOp::Max,
54                    _ => unreachable!(),
55                };
56
57                bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
58            }
59            // Used to implement the _mm256_round_ps function.
60            // Rounds the elements of `op` according to `rounding`.
61            "round.ps.256" => {
62                let [op, rounding] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
63
64                round_all::<rustc_apfloat::ieee::Single>(this, op, rounding, dest)?;
65            }
66            // Used to implement the _mm256_round_pd function.
67            // Rounds the elements of `op` according to `rounding`.
68            "round.pd.256" => {
69                let [op, rounding] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
70
71                round_all::<rustc_apfloat::ieee::Double>(this, op, rounding, dest)?;
72            }
73            // Used to implement _mm256_{rcp,rsqrt}_ps functions.
74            // Performs the operations on all components of `op`.
75            "rcp.ps.256" | "rsqrt.ps.256" => {
76                let [op] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
77
78                let which = match unprefixed_name {
79                    "rcp.ps.256" => FloatUnaryOp::Rcp,
80                    "rsqrt.ps.256" => FloatUnaryOp::Rsqrt,
81                    _ => unreachable!(),
82                };
83
84                unary_op_ps(this, which, op, dest)?;
85            }
86            // Used to implement the _mm256_dp_ps function.
87            "dp.ps.256" => {
88                let [left, right, imm] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
89
90                conditional_dot_product(this, left, right, imm, dest)?;
91            }
92            // Used to implement the _mm256_h{add,sub}_p{s,d} functions.
93            // Horizontally add/subtract adjacent floating point values
94            // in `left` and `right`.
95            "hadd.ps.256" | "hadd.pd.256" | "hsub.ps.256" | "hsub.pd.256" => {
96                let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
97
98                let which = match unprefixed_name {
99                    "hadd.ps.256" | "hadd.pd.256" => mir::BinOp::Add,
100                    "hsub.ps.256" | "hsub.pd.256" => mir::BinOp::Sub,
101                    _ => unreachable!(),
102                };
103
104                horizontal_bin_op(this, which, /*saturating*/ false, left, right, dest)?;
105            }
106            // Used to implement the _mm256_cmp_ps function.
107            // Performs a comparison operation on each component of `left`
108            // and `right`. For each component, returns 0 if false or u32::MAX
109            // if true.
110            "cmp.ps.256" => {
111                let [left, right, imm] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
112
113                let which =
114                    FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
115
116                bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
117            }
118            // Used to implement the _mm256_cmp_pd function.
119            // Performs a comparison operation on each component of `left`
120            // and `right`. For each component, returns 0 if false or u64::MAX
121            // if true.
122            "cmp.pd.256" => {
123                let [left, right, imm] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
124
125                let which =
126                    FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
127
128                bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
129            }
130            // Used to implement the _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_cvtpd_epi32
131            // and _mm256_cvttpd_epi32 functions.
132            // Converts packed f32/f64 to packed i32.
133            "cvt.ps2dq.256" | "cvtt.ps2dq.256" | "cvt.pd2dq.256" | "cvtt.pd2dq.256" => {
134                let [op] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
135
136                let rnd = match unprefixed_name {
137                    // "current SSE rounding mode", assume nearest
138                    "cvt.ps2dq.256" | "cvt.pd2dq.256" => rustc_apfloat::Round::NearestTiesToEven,
139                    // always truncate
140                    "cvtt.ps2dq.256" | "cvtt.pd2dq.256" => rustc_apfloat::Round::TowardZero,
141                    _ => unreachable!(),
142                };
143
144                convert_float_to_int(this, op, rnd, dest)?;
145            }
146            // Used to implement the _mm_permutevar_ps and _mm256_permutevar_ps functions.
147            // Shuffles 32-bit floats from `data` using `control` as control. Each 128-bit
148            // chunk is shuffled independently: this means that we view the vector as a
149            // sequence of 4-element arrays, and we shuffle each of these arrays, where
150            // `control` determines which element of the current `data` array is written.
151            "vpermilvar.ps" | "vpermilvar.ps.256" => {
152                let [data, control] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
153
154                let (data, data_len) = this.project_to_simd(data)?;
155                let (control, control_len) = this.project_to_simd(control)?;
156                let (dest, dest_len) = this.project_to_simd(dest)?;
157
158                assert_eq!(dest_len, data_len);
159                assert_eq!(dest_len, control_len);
160
161                for i in 0..dest_len {
162                    let control = this.project_index(&control, i)?;
163
164                    // Each 128-bit chunk is shuffled independently. Since each chunk contains
165                    // four 32-bit elements, only two bits from `control` are used. To read the
166                    // value from the current chunk, add the destination index truncated to a multiple
167                    // of 4.
168                    let chunk_base = i & !0b11;
169                    let src_i = u64::from(this.read_scalar(&control)?.to_u32()? & 0b11)
170                        .strict_add(chunk_base);
171
172                    this.copy_op(
173                        &this.project_index(&data, src_i)?,
174                        &this.project_index(&dest, i)?,
175                    )?;
176                }
177            }
178            // Used to implement the _mm_permutevar_pd and _mm256_permutevar_pd functions.
179            // Shuffles 64-bit floats from `left` using `right` as control. Each 128-bit
180            // chunk is shuffled independently: this means that we view the vector as
181            // a sequence of 2-element arrays, and we shuffle each of these arrays,
182            // where `right` determines which element of the current `left` array is
183            // written.
184            "vpermilvar.pd" | "vpermilvar.pd.256" => {
185                let [data, control] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
186
187                let (data, data_len) = this.project_to_simd(data)?;
188                let (control, control_len) = this.project_to_simd(control)?;
189                let (dest, dest_len) = this.project_to_simd(dest)?;
190
191                assert_eq!(dest_len, data_len);
192                assert_eq!(dest_len, control_len);
193
194                for i in 0..dest_len {
195                    let control = this.project_index(&control, i)?;
196
197                    // Each 128-bit chunk is shuffled independently. Since each chunk contains
198                    // two 64-bit elements, only the second bit from `control` is used (yes, the
199                    // second instead of the first, ask Intel). To read the value from the current
200                    // chunk, add the destination index truncated to a multiple of 2.
201                    let chunk_base = i & !1;
202                    let src_i =
203                        ((this.read_scalar(&control)?.to_u64()? >> 1) & 1).strict_add(chunk_base);
204
205                    this.copy_op(
206                        &this.project_index(&data, src_i)?,
207                        &this.project_index(&dest, i)?,
208                    )?;
209                }
210            }
211            // Used to implement the _mm256_permute2f128_ps, _mm256_permute2f128_pd and
212            // _mm256_permute2f128_si256 functions. Regardless of the suffix in the name
213            // thay all can be considered to operate on vectors of 128-bit elements.
214            // For each 128-bit element of `dest`, copies one from `left`, `right` or
215            // zero, according to `imm`.
216            "vperm2f128.ps.256" | "vperm2f128.pd.256" | "vperm2f128.si.256" => {
217                let [left, right, imm] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
218
219                assert_eq!(dest.layout, left.layout);
220                assert_eq!(dest.layout, right.layout);
221                assert_eq!(dest.layout.size.bits(), 256);
222
223                // Transmute to `[u128; 2]` to process each 128-bit chunk independently.
224                let u128x2_layout =
225                    this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u128, 2))?;
226                let left = left.transmute(u128x2_layout, this)?;
227                let right = right.transmute(u128x2_layout, this)?;
228                let dest = dest.transmute(u128x2_layout, this)?;
229
230                let imm = this.read_scalar(imm)?.to_u8()?;
231
232                for i in 0..2 {
233                    let dest = this.project_index(&dest, i)?;
234
235                    let imm = match i {
236                        0 => imm & 0xF,
237                        1 => imm >> 4,
238                        _ => unreachable!(),
239                    };
240                    if imm & 0b100 != 0 {
241                        this.write_scalar(Scalar::from_u128(0), &dest)?;
242                    } else {
243                        let src = match imm {
244                            0b00 => this.project_index(&left, 0)?,
245                            0b01 => this.project_index(&left, 1)?,
246                            0b10 => this.project_index(&right, 0)?,
247                            0b11 => this.project_index(&right, 1)?,
248                            _ => unreachable!(),
249                        };
250                        this.copy_op(&src, &dest)?;
251                    }
252                }
253            }
254            // Used to implement the _mm_maskload_ps, _mm_maskload_pd, _mm256_maskload_ps
255            // and _mm256_maskload_pd functions.
256            // For the element `i`, if the high bit of the `i`-th element of `mask`
257            // is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is
258            // loaded.
259            "maskload.ps" | "maskload.pd" | "maskload.ps.256" | "maskload.pd.256" => {
260                let [ptr, mask] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
261
262                mask_load(this, ptr, mask, dest)?;
263            }
264            // Used to implement the _mm_maskstore_ps, _mm_maskstore_pd, _mm256_maskstore_ps
265            // and _mm256_maskstore_pd functions.
266            // For the element `i`, if the high bit of the element `i`-th of `mask`
267            // is one, it is stored into `ptr.wapping_add(i)`.
268            // Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores.
269            "maskstore.ps" | "maskstore.pd" | "maskstore.ps.256" | "maskstore.pd.256" => {
270                let [ptr, mask, value] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
271
272                mask_store(this, ptr, mask, value)?;
273            }
274            // Used to implement the _mm256_lddqu_si256 function.
275            // Reads a 256-bit vector from an unaligned pointer. This intrinsic
276            // is expected to perform better than a regular unaligned read when
277            // the data crosses a cache line, but for Miri this is just a regular
278            // unaligned read.
279            "ldu.dq.256" => {
280                let [src_ptr] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
281                let src_ptr = this.read_pointer(src_ptr)?;
282                let dest = dest.force_mplace(this)?;
283
284                // Unaligned copy, which is what we want.
285                this.mem_copy(src_ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
286            }
287            // Used to implement the _mm256_testz_si256, _mm256_testc_si256 and
288            // _mm256_testnzc_si256 functions.
289            // Tests `op & mask == 0`, `op & mask == mask` or
290            // `op & mask != 0 && op & mask != mask`
291            "ptestz.256" | "ptestc.256" | "ptestnzc.256" => {
292                let [op, mask] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
293
294                let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
295                let res = match unprefixed_name {
296                    "ptestz.256" => all_zero,
297                    "ptestc.256" => masked_set,
298                    "ptestnzc.256" => !all_zero && !masked_set,
299                    _ => unreachable!(),
300                };
301
302                this.write_scalar(Scalar::from_i32(res.into()), dest)?;
303            }
304            // Used to implement the _mm256_testz_pd, _mm256_testc_pd, _mm256_testnzc_pd
305            // _mm_testz_pd, _mm_testc_pd, _mm_testnzc_pd, _mm256_testz_ps,
306            // _mm256_testc_ps, _mm256_testnzc_ps, _mm_testz_ps, _mm_testc_ps and
307            // _mm_testnzc_ps functions.
308            // Calculates two booleans:
309            // `direct`, which is true when the highest bit of each element of `op & mask` is zero.
310            // `negated`, which is true when the highest bit of each element of `!op & mask` is zero.
311            // Return `direct` (testz), `negated` (testc) or `!direct & !negated` (testnzc)
312            "vtestz.pd.256" | "vtestc.pd.256" | "vtestnzc.pd.256" | "vtestz.pd" | "vtestc.pd"
313            | "vtestnzc.pd" | "vtestz.ps.256" | "vtestc.ps.256" | "vtestnzc.ps.256"
314            | "vtestz.ps" | "vtestc.ps" | "vtestnzc.ps" => {
315                let [op, mask] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
316
317                let (direct, negated) = test_high_bits_masked(this, op, mask)?;
318                let res = match unprefixed_name {
319                    "vtestz.pd.256" | "vtestz.pd" | "vtestz.ps.256" | "vtestz.ps" => direct,
320                    "vtestc.pd.256" | "vtestc.pd" | "vtestc.ps.256" | "vtestc.ps" => negated,
321                    "vtestnzc.pd.256" | "vtestnzc.pd" | "vtestnzc.ps.256" | "vtestnzc.ps" =>
322                        !direct && !negated,
323                    _ => unreachable!(),
324                };
325
326                this.write_scalar(Scalar::from_i32(res.into()), dest)?;
327            }
328            // Used to implement the `_mm256_zeroupper` and `_mm256_zeroall` functions.
329            // These function clear out the upper 128 bits of all avx registers or
330            // zero out all avx registers respectively.
331            "vzeroupper" | "vzeroall" => {
332                // These functions are purely a performance hint for the CPU.
333                // Any registers currently in use will be saved beforehand by the
334                // compiler, making these functions no-ops.
335
336                // The only thing that needs to be ensured is the correct calling convention.
337                let [] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
338            }
339            _ => return interp_ok(EmulateItemResult::NotSupported),
340        }
341        interp_ok(EmulateItemResult::NeedsReturn)
342    }
343}