summary refs log tree commit diff
path: root/src/subtitle_extraction/whisper.rs
blob: ffa2e47240322963a7126439a000a8d716d31838 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
use std::{
    io::{self, BufRead, BufReader},
    net::{TcpListener, TcpStream},
    sync::mpsc,
};

use anyhow::Context;
use ffmpeg::{filter, frame};
use serde::Deserialize;

use crate::{subtitle_extraction::*, tracks::StreamIndex};

/// One transcript entry as deserialized from a single JSON line sent by
/// FFmpeg's whisper filter over the TCP transcript connection.
#[derive(Debug, Deserialize)]
struct WhisperCue {
    /// Start of the cue; interpreted as milliseconds when converted to a
    /// `gst::ClockTime` in `handle_packet`.
    start: u64,
    /// End of the cue; interpreted as milliseconds when converted to a
    /// `gst::ClockTime` in `handle_packet`.
    end: u64,
    /// Transcribed text of the cue.
    text: String,
}

pub fn generate_whisper_subtitles(
    // stream index to use when storing generated subtitles, this index
    // already has to be in TRACKS when this function is called!
    stream_ix: StreamIndex,
    context: ffmpeg::codec::Context,
    time_base: ffmpeg::Rational,
    packet_rx: mpsc::Receiver<ffmpeg::Packet>,
    sender: ComponentSender<SubtitleExtractor>,
) -> anyhow::Result<()> {
    // FFmpeg's whisper filter will send the generated subtitles to us as JSON
    // objects over a TCP socket. This is the best solution I could find
    // because we need to use one of the protocols in
    // https://ffmpeg.org/ffmpeg-protocols.html, and TCP is the only one on the
    // list which is portable and supports non-blocking IO in Rust.
    let tcp_listener = TcpListener::bind("127.0.0.1:0")?;

    let mut decoder = context
        .decoder()
        .audio()
        .with_context(|| format!("error creating subtitle decoder for stream {}", stream_ix))?;

    let mut filter = filter::Graph::new();

    let abuffer_args = format!(
        "time_base={}:sample_rate={}:sample_fmt={}:channel_layout=0x{:x}",
        time_base,
        decoder.rate(),
        decoder.format().name(),
        decoder.channel_layout().bits()
    );

    let whisper_args = format!(
        "model={}:queue={}:destination=tcp\\\\://127.0.0.1\\\\:{}:format=json",
        "/Users/malte/repos/lleap/whisper-models/ggml-large-v3.bin",
        30,
        tcp_listener.local_addr()?.port()
    );
    let filter_spec = format!("[src] whisper={} [sink]", whisper_args);

    filter.add(&filter::find("abuffer").unwrap(), "src", &abuffer_args)?;
    filter.add(&filter::find("abuffersink").unwrap(), "sink", "")?;
    filter
        .output("src", 0)?
        .input("sink", 0)?
        .parse(&filter_spec)?;
    filter.validate()?;

    let mut source_ctx = filter.get("src").unwrap();
    let mut sink_ctx = filter.get("sink").unwrap();

    let (tcp_stream, _) = tcp_listener.accept()?;
    tcp_stream.set_nonblocking(true)?;

    let mut transcript_reader = BufReader::new(tcp_stream);
    let mut line_buf = String::new();

    while let Ok(packet) = packet_rx.recv() {
        handle_packet(
            stream_ix,
            &sender,
            &mut decoder,
            source_ctx.source(),
            sink_ctx.sink(),
            &mut transcript_reader,
            &mut line_buf,
            packet,
        )
        .unwrap_or_else(|e| log::error!("error handling audio packet: {}", e))
    }

    Ok(())
}

// TODO: can we do this without passing all the arguments? this is kinda ugly
/// Decodes one audio packet, feeds it through the whisper filter graph and
/// publishes every transcript cue that is currently available.
///
/// `transcript_reader` is expected to wrap a non-blocking socket: reading
/// stops as soon as the socket reports `WouldBlock`. `line_buf` carries a
/// partially received transcript line from one call to the next, so callers
/// must pass the same buffer every time and must not clear it themselves.
///
/// # Errors
///
/// Returns an error if decoding or feeding the filter graph fails, if the
/// socket reports an error other than `WouldBlock`, or if a transcript line
/// is not valid `WhisperCue` JSON.
///
/// # Panics
///
/// Panics if `stream_ix` has no entry in `SUBTITLE_TRACKS` or if the cue
/// cannot be delivered through `sender`.
fn handle_packet(
    stream_ix: StreamIndex,
    sender: &ComponentSender<SubtitleExtractor>,
    decoder: &mut ffmpeg::decoder::Audio,
    mut source: filter::Source,
    mut sink: filter::Sink,
    transcript_reader: &mut BufReader<TcpStream>,
    line_buf: &mut String,
    packet: ffmpeg::Packet,
) -> anyhow::Result<()> {
    decoder.send_packet(&packet)?;

    let mut decoded = frame::Audio::empty();
    while decoder.receive_frame(&mut decoded).is_ok() {
        source.add(&decoded)?;
    }

    // The filtered audio is discarded; only the transcript sent over the
    // socket is of interest. NOTE(review): presumably pulling from the sink is
    // what drives the whisper filter -- confirm.
    let mut out_frame = frame::Audio::empty();
    while sink.frame(&mut out_frame).is_ok() {}

    // Drain every complete transcript line that is available right now.
    loop {
        match transcript_reader.read_line(line_buf) {
            // End of stream with nothing buffered: nothing left to read.
            Ok(0) if line_buf.is_empty() => return Ok(()),
            Ok(_) => {
                // Parse into an owned value first so that the buffer can be
                // reset before any error is propagated.
                let parsed = match line_buf.trim() {
                    "" => None,
                    line => Some(serde_json::from_str::<WhisperCue>(line)),
                };
                line_buf.clear();

                let whisper_cue = match parsed {
                    Some(result) => result?,
                    // Blank line, nothing to publish.
                    None => continue,
                };

                let cue = SubtitleCue {
                    start: gst::ClockTime::from_mseconds(whisper_cue.start),
                    end: gst::ClockTime::from_mseconds(whisper_cue.end),
                    text: whisper_cue.text,
                };

                // TODO deduplicate this vs. the code in embedded.rs
                SUBTITLE_TRACKS
                    .write()
                    .get_mut(&stream_ix)
                    .expect("stream index must be registered before transcription starts")
                    .cues
                    .push(cue.clone());
                sender
                    .output(SubtitleExtractorOutput::NewCue(stream_ix, cue))
                    .expect("receiver of subtitle extractor output was dropped");
            }
            // No (further) complete line yet. Whatever part of a line has
            // already been appended to `line_buf` is kept for the next call.
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(()),
            Err(e) => {
                // Do not let a broken fragment be glued to a later line.
                line_buf.clear();
                return Err(e.into());
            }
        }
    }
}