yolo_binding\core/
predict.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
use super::{YOLODevice, YOLO};
use std::error::Error;
// use std::time::Instant;
use tch::{Device, Tensor};

/// Runs a forward pass of the YOLO model on `input` and returns the raw output tensor.
///
/// The input is first copied to the device selected by `model.cuda`:
/// `YOLODevice::Cpu` maps to `Device::Cpu`, and `YOLODevice::Gpu` maps to
/// `Device::cuda_if_available()`. The caller's tensor is left untouched.
///
/// # Arguments
///
/// * `model` - the loaded YOLO model; its `yolo_model` field performs the inference.
/// * `input` - the input tensor to run through the model.
///
/// # Errors
///
/// Returns the error reported by `forward_ts` when the forward pass fails,
/// instead of panicking, so callers can handle the failure through the
/// `Result` this function already advertises.
pub fn pred(model: &YOLO, input: &Tensor) -> Result<Tensor, Box<dyn Error>> {
    let device = match model.cuda {
        YOLODevice::Cpu => Device::Cpu,
        YOLODevice::Gpu => Device::cuda_if_available(),
    };
    // Shadow the borrowed input with a copy living on the target device.
    let input = input.to_device(device);
    // Propagate inference failures to the caller rather than aborting the process.
    let output = model.yolo_model.forward_ts(&[input])?;
    Ok(output)
}