diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml
index f2d2826..31f292f 100644
--- a/src-tauri/Cargo.toml
+++ b/src-tauri/Cargo.toml
@@ -52,8 +52,15 @@ windows = { version = "0.62", features = [
     "Win32_System_Com",
     "Win32_Foundation",
     "Win32_Media_Audio",
-    "Win32_Media_Audio_Endpoints",
+    "Win32_Devices_Properties",
+    "Win32_Media_KernelStreaming",
+    "Win32_System_Com_StructuredStorage",
     "Win32_System_Threading",
+    "Win32_Security",
+    "Win32_System_SystemServices",
+    "Win32_System_Variant",
+    "Win32_Media_Multimedia",
+    "Win32_UI_Shell_PropertiesSystem",
 ] }
 
 [patch.crates-io]
diff --git a/src-tauri/src/ml/audio.rs b/src-tauri/src/ml/audio.rs
index 6eb8ea4..67ea1ae 100644
--- a/src-tauri/src/ml/audio.rs
+++ b/src-tauri/src/ml/audio.rs
@@ -2,12 +2,15 @@
 //!
 //! This module handles audio recording in a thread-safe way by using
 //! a shared buffer that can be accessed from the Tauri state.
+//!
+//! On Windows, it captures both microphone input AND system audio (loopback)
+//! to record both sides of a meeting conversation.
 
 use parking_lot::Mutex;
 use std::sync::Arc;
 use std::thread::{self, JoinHandle};
 use thiserror::Error;
-use tracing::{debug, error, info, warn};
+use tracing::{error, info, warn};
 
 /// Sample rate expected by Whisper (16kHz mono)
 pub const WHISPER_SAMPLE_RATE: u32 = 16000;
@@ -33,6 +36,8 @@ pub struct AudioCapture {
     is_recording: Arc<Mutex<bool>>,
     should_stop: Arc<Mutex<bool>>,
     recording_thread: Option<JoinHandle<()>>,
+    #[cfg(target_os = "windows")]
+    loopback_capture: Option<super::wasapi_loopback::WasapiLoopback>,
 }
 
 // Implement Send + Sync manually since our struct only contains thread-safe types
@@ -47,10 +52,13 @@ impl AudioCapture {
             is_recording: Arc::new(Mutex::new(false)),
             should_stop: Arc::new(Mutex::new(false)),
             recording_thread: None,
+            #[cfg(target_os = "windows")]
+            loopback_capture: None,
         })
     }
 
     /// Start recording audio in a background thread.
+    /// Captures both microphone input and system audio (loopback) on Windows.
     pub fn start_recording(&mut self) -> Result<(), AudioError> {
         if *self.is_recording.lock() {
             warn!("Already recording");
@@ -62,82 +70,83 @@ impl AudioCapture {
         *self.should_stop.lock() = false;
         *self.is_recording.lock() = true;
 
+        // Start WASAPI loopback capture for system audio (Windows only)
+        #[cfg(target_os = "windows")]
+        {
+            let mut loopback = super::wasapi_loopback::WasapiLoopback::new();
+            if let Err(e) = loopback.start_capture(Arc::clone(&self.buffer)) {
+                warn!("Failed to start WASAPI loopback: {}", e);
+            } else {
+                info!("WASAPI loopback capture initialized");
+            }
+            self.loopback_capture = Some(loopback);
+        }
+
         let buffer = Arc::clone(&self.buffer);
         let is_recording = Arc::clone(&self.is_recording);
         let should_stop = Arc::clone(&self.should_stop);
 
-        // Spawn a thread to handle audio capture
+        // Spawn a thread to handle microphone capture via cpal
         let handle = thread::spawn(move || {
             use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
-            use cpal::{SampleRate, StreamConfig};
 
             let host = cpal::default_host();
 
-            let device = match host.default_input_device() {
-                Some(d) => d,
-                None => {
-                    error!("No input device available");
-                    *is_recording.lock() = false;
-                    return;
-                }
-            };
+            // Get microphone (input device)
+            let input_device = host.default_input_device();
 
-            info!("Using audio input device: {}", device.name().unwrap_or_default());
-
-            // Get supported config
-            let supported_config = match device.default_input_config() {
-                Ok(c) => c,
-                Err(e) => {
-                    error!("Failed to get input config: {}", e);
-                    *is_recording.lock() = false;
-                    return;
-                }
-            };
-
-            debug!("Supported config: {:?}", supported_config);
-
-            // We want 16kHz mono for Whisper
-            let config = StreamConfig {
-                channels: 1,
-                sample_rate: SampleRate(WHISPER_SAMPLE_RATE),
-                buffer_size: cpal::BufferSize::Default,
-            };
-
-            let buffer_clone = Arc::clone(&buffer);
-            let err_fn = |err| error!("Audio stream error: {}", err);
-
-            // Build the input stream
-            let stream = match device.build_input_stream(
-                &config,
-                move |data: &[f32], _: &cpal::InputCallbackInfo| {
-                    buffer_clone.lock().extend_from_slice(data);
-                },
-                err_fn,
-                None,
-            ) {
-                Ok(s) => s,
-                Err(e) => {
-                    error!("Failed to build input stream: {}", e);
-                    *is_recording.lock() = false;
-                    return;
-                }
-            };
-
-            if let Err(e) = stream.play() {
-                error!("Failed to start stream: {}", e);
-                *is_recording.lock() = false;
-                return;
+            if input_device.is_none() {
+                warn!("No microphone available - only system audio will be captured");
             }
 
-            info!("Audio recording started");
+            let mut streams: Vec<cpal::Stream> = Vec::new();
 
-            // Keep the stream alive until stop is requested
+            // Set up microphone capture
+            if let Some(ref device) = input_device {
+                info!("Microphone device: {}", device.name().unwrap_or_default());
+
+                if let Ok(config) = device.default_input_config() {
+                    let sample_rate = config.sample_rate().0;
+                    let channels = config.channels();
+                    let sample_format = config.sample_format();
+                    info!("Microphone config: {} Hz, {} ch, {:?}", sample_rate, channels, sample_format);
+
+                    let buffer_clone = Arc::clone(&buffer);
+                    let stream_config: cpal::StreamConfig = config.into();
+
+                    let stream = build_input_stream(
+                        device,
+                        &stream_config,
+                        sample_format,
+                        sample_rate,
+                        channels,
+                        buffer_clone,
+                        "mic",
+                    );
+
+                    if let Some(s) = stream {
+                        if s.play().is_ok() {
+                            info!("Microphone capture started");
+                            streams.push(s);
+                        }
+                    }
+                }
+            }
+
+            // Note: Even if mic fails, WASAPI loopback may still be capturing system audio
+            if streams.is_empty() {
+                warn!("No microphone stream started - relying on WASAPI loopback for system audio");
+            } else {
+                info!("Audio recording started with {} microphone stream(s)", streams.len());
+            }
+
+            // Keep the streams alive until stop is requested
             while !*should_stop.lock() {
                 thread::sleep(std::time::Duration::from_millis(100));
             }
 
-            // Stream is automatically stopped when dropped
-            drop(stream);
+            // Streams are automatically stopped when dropped
+            drop(streams);
 
             *is_recording.lock() = false;
             info!("Audio recording stopped");
         });
@@ -151,7 +160,13 @@ impl AudioCapture {
         // Signal the thread to stop
         *self.should_stop.lock() = true;
 
-        // Wait for the thread to finish
+        // Stop WASAPI loopback capture
+        #[cfg(target_os = "windows")]
+        if let Some(ref mut loopback) = self.loopback_capture {
+            loopback.stop_capture();
+        }
+
+        // Wait for the microphone thread to finish
         if let Some(handle) = self.recording_thread.take() {
             let _ = handle.join();
         }
@@ -201,6 +216,96 @@ impl Drop for AudioCapture {
     }
 }
 
+/// Build an input stream for the given device with automatic format handling.
+fn build_input_stream(
+    device: &cpal::Device,
+    config: &cpal::StreamConfig,
+    sample_format: cpal::SampleFormat,
+    sample_rate: u32,
+    channels: u16,
+    buffer: Arc<Mutex<Vec<f32>>>,
+    source_name: &'static str,
+) -> Option<cpal::Stream> {
+    use cpal::traits::DeviceTrait;
+
+    let err_fn = move |err| error!("Audio stream error ({}): {}", source_name, err);
+
+    // Create a processing closure that handles mono conversion and resampling
+    let make_processor = move || {
+        let buffer = Arc::clone(&buffer);
+        move |samples: Vec<f32>| {
+            // Convert to mono if stereo
+            let mono_samples: Vec<f32> = if channels > 1 {
+                samples
+                    .chunks(channels as usize)
+                    .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
+                    .collect()
+            } else {
+                samples
+            };
+
+            // Resample to 16kHz if needed
+            let resampled = if sample_rate != WHISPER_SAMPLE_RATE {
+                resample(&mono_samples, sample_rate, WHISPER_SAMPLE_RATE)
+            } else {
+                mono_samples
+            };
+
+            buffer.lock().extend_from_slice(&resampled);
+        }
+    };
+
+    let stream = match sample_format {
+        cpal::SampleFormat::F32 => {
+            let process = make_processor();
+            device.build_input_stream(
+                config,
+                move |data: &[f32], _: &cpal::InputCallbackInfo| {
+                    process(data.to_vec());
+                },
+                err_fn,
+                None,
+            )
+        }
+        cpal::SampleFormat::I16 => {
+            let process = make_processor();
+            device.build_input_stream(
+                config,
+                move |data: &[i16], _: &cpal::InputCallbackInfo| {
+                    let samples: Vec<f32> = data.iter().map(|&s| s as f32 / 32768.0).collect();
+                    process(samples);
+                },
+                err_fn,
+                None,
+            )
+        }
+        cpal::SampleFormat::I32 => {
+            let process = make_processor();
+            device.build_input_stream(
+                config,
+                move |data: &[i32], _: &cpal::InputCallbackInfo| {
+                    let samples: Vec<f32> = data.iter().map(|&s| s as f32 / 2147483648.0).collect();
+                    process(samples);
+                },
+                err_fn,
+                None,
+            )
+        }
+        format => {
+            error!("Unsupported sample format for {}: {:?}", source_name, format);
+            return None;
+        }
+    };
+
+    match stream {
+        Ok(s) => Some(s),
+        Err(e) => {
+            warn!("Failed to build {} stream: {}", source_name, e);
+            None
+        }
+    }
+}
+
 /// Convert audio samples from i16 to f32 format.
 pub fn i16_to_f32(samples: &[i16]) -> Vec<f32> {
     samples.iter().map(|&s| s as f32 / 32768.0).collect()
diff --git a/src-tauri/src/ml/mod.rs b/src-tauri/src/ml/mod.rs
index f994797..f24b02f 100644
--- a/src-tauri/src/ml/mod.rs
+++ b/src-tauri/src/ml/mod.rs
@@ -11,6 +11,9 @@ pub mod transcriber;
 pub mod summarizer;
 pub mod vad;
 
+#[cfg(target_os = "windows")]
+pub mod wasapi_loopback;
+
 pub use audio::AudioCapture;
 pub use transcriber::WhisperTranscriber;
 pub use summarizer::LlamaSummarizer;
diff --git a/src-tauri/src/ml/wasapi_loopback.rs b/src-tauri/src/ml/wasapi_loopback.rs
new file mode 100644
index 0000000..574a088
--- /dev/null
+++ b/src-tauri/src/ml/wasapi_loopback.rs
@@ -0,0 +1,311 @@
+//! WASAPI loopback capture for Windows.
+//!
+//! This module captures system audio (what's playing through speakers)
+//! using Windows Audio Session API (WASAPI) in loopback mode.
+
+use parking_lot::Mutex;
+use std::sync::Arc;
+use std::thread::{self, JoinHandle};
+use tracing::{error, info, warn};
+use windows::Win32::Media::Audio;
+use windows::Win32::System::Com;
+
+/// WASAPI loopback capture state.
+pub struct WasapiLoopback {
+    is_capturing: Arc<Mutex<bool>>,
+    should_stop: Arc<Mutex<bool>>,
+    capture_thread: Option<JoinHandle<()>>,
+}
+
+impl WasapiLoopback {
+    /// Create a new WASAPI loopback capture instance.
+    pub fn new() -> Self {
+        Self {
+            is_capturing: Arc::new(Mutex::new(false)),
+            should_stop: Arc::new(Mutex::new(false)),
+            capture_thread: None,
+        }
+    }
+
+    /// Start capturing system audio in loopback mode.
+    pub fn start_capture(&mut self, output_buffer: Arc<Mutex<Vec<f32>>>) -> Result<(), String> {
+        if *self.is_capturing.lock() {
+            warn!("WASAPI loopback already capturing");
+            return Ok(());
+        }
+
+        *self.should_stop.lock() = false;
+        *self.is_capturing.lock() = true;
+
+        let is_capturing = Arc::clone(&self.is_capturing);
+        let should_stop = Arc::clone(&self.should_stop);
+
+        let handle = thread::spawn(move || {
+            if let Err(e) = capture_loopback_audio(output_buffer, should_stop.clone()) {
+                error!("WASAPI loopback capture error: {}", e);
+            }
+            *is_capturing.lock() = false;
+            info!("WASAPI loopback capture stopped");
+        });
+
+        self.capture_thread = Some(handle);
+        Ok(())
+    }
+
+    /// Stop capturing.
+    pub fn stop_capture(&mut self) {
+        *self.should_stop.lock() = true;
+        if let Some(handle) = self.capture_thread.take() {
+            let _ = handle.join();
+        }
+    }
+
+    /// Check if currently capturing.
+    pub fn is_capturing(&self) -> bool {
+        *self.is_capturing.lock()
+    }
+}
+
+impl Drop for WasapiLoopback {
+    fn drop(&mut self) {
+        self.stop_capture();
+    }
+}
+
+/// Target sample rate for Whisper (16kHz)
+const TARGET_SAMPLE_RATE: u32 = 16000;
+
+/// Capture loopback audio from the default render device.
+fn capture_loopback_audio(
+    buffer: Arc<Mutex<Vec<f32>>>,
+    should_stop: Arc<Mutex<bool>>,
+) -> Result<(), String> {
+    unsafe {
+        // Initialize COM
+        Com::CoInitializeEx(None, Com::COINIT_MULTITHREADED)
+            .ok()
+            .map_err(|e| format!("Failed to initialize COM: {}", e))?;
+
+        // Create device enumerator
+        let enumerator: Audio::IMMDeviceEnumerator =
+            Com::CoCreateInstance(&Audio::MMDeviceEnumerator, None, Com::CLSCTX_ALL)
+                .map_err(|e| format!("Failed to create device enumerator: {}", e))?;
+
+        // Get default render (output) device - this is key for loopback!
+        let device = enumerator
+            .GetDefaultAudioEndpoint(Audio::eRender, Audio::eConsole)
+            .map_err(|e| format!("Failed to get default render device: {}", e))?;
+
+        // Get device name for logging
+        if let Ok(id) = device.GetId() {
+            info!("WASAPI loopback device: {:?}", id.to_string());
+        }
+
+        // Activate audio client
+        let audio_client: Audio::IAudioClient = device
+            .Activate(Com::CLSCTX_ALL, None)
+            .map_err(|e| format!("Failed to activate audio client: {}", e))?;
+
+        // Get the mix format (what the device is actually using)
+        let mix_format = audio_client
+            .GetMixFormat()
+            .map_err(|e| format!("Failed to get mix format: {}", e))?;
+
+        let format = &*mix_format;
+        let sample_rate = format.nSamplesPerSec;
+        let channels = format.nChannels;
+        let bits_per_sample = format.wBitsPerSample;
+        let block_align = format.nBlockAlign;
+
+        info!(
+            "WASAPI loopback format: {} Hz, {} ch, {} bits",
+            sample_rate, channels, bits_per_sample
+        );
+
+        // Initialize audio client in loopback mode
+        // Key flags: AUDCLNT_STREAMFLAGS_LOOPBACK for capturing output
+        // Must use shared mode (not exclusive) for loopback
+        let buffer_duration = 10_000_000i64; // 1 second in 100-nanosecond units
+
+        audio_client
+            .Initialize(
+                Audio::AUDCLNT_SHAREMODE_SHARED,
+                Audio::AUDCLNT_STREAMFLAGS_LOOPBACK,
+                buffer_duration,
+                0,
+                mix_format,
+                None,
+            )
+            .map_err(|e| format!("Failed to initialize audio client: {}", e))?;
+
+        // Get capture client
+        let capture_client: Audio::IAudioCaptureClient = audio_client
+            .GetService()
+            .map_err(|e| format!("Failed to get capture client: {}", e))?;
+
+        // Start capturing
+        audio_client
+            .Start()
+            .map_err(|e| format!("Failed to start audio client: {}", e))?;
+
+        info!("WASAPI loopback capture started");
+
+        // Capture loop - use polling since event mode doesn't work for loopback
+        while !*should_stop.lock() {
+            // Sleep a bit to avoid busy-waiting (10ms = 100Hz polling)
+            thread::sleep(std::time::Duration::from_millis(10));
+
+            // Get available frames
+            let frames_available = match capture_client.GetNextPacketSize() {
+                Ok(frames) => frames,
+                Err(e) => {
+                    warn!("Failed to get packet size: {}", e);
+                    continue;
+                }
+            };
+
+            if frames_available == 0 {
+                continue;
+            }
+
+            // Get buffer
+            let mut data_ptr: *mut u8 = std::ptr::null_mut();
+            let mut num_frames: u32 = 0;
+            let mut flags: u32 = 0;
+
+            if let Err(e) = capture_client.GetBuffer(
+                &mut data_ptr,
+                &mut num_frames,
+                &mut flags,
+                None,
+                None,
+            ) {
+                warn!("Failed to get buffer: {}", e);
+                continue;
+            }
+
+            if num_frames > 0 && !data_ptr.is_null() {
+                // Convert to f32 samples
+                let samples = convert_to_f32(
+                    data_ptr,
+                    num_frames as usize,
+                    channels as usize,
+                    bits_per_sample,
+                    block_align as usize,
+                );
+
+                // Convert to mono
+                let mono_samples = to_mono(&samples, channels as usize);
+
+                // Resample to 16kHz if needed
+                let resampled = if sample_rate != TARGET_SAMPLE_RATE {
+                    resample(&mono_samples, sample_rate, TARGET_SAMPLE_RATE)
+                } else {
+                    mono_samples
+                };
+
+                // Add to buffer
+                buffer.lock().extend_from_slice(&resampled);
+            }
+
+            // Release buffer
+            if let Err(e) = capture_client.ReleaseBuffer(num_frames) {
+                warn!("Failed to release buffer: {}", e);
+            }
+        }
+
+        // Stop and cleanup
+        let _ = audio_client.Stop();
+
+        Ok(())
+    }
+}
+
+/// Convert raw audio bytes to f32 samples.
+fn convert_to_f32(
+    data: *mut u8,
+    num_frames: usize,
+    channels: usize,
+    bits_per_sample: u16,
+    block_align: usize,
+) -> Vec<f32> {
+    let total_samples = num_frames * channels;
+    let mut samples = Vec::with_capacity(total_samples);
+
+    unsafe {
+        match bits_per_sample {
+            16 => {
+                let ptr = data as *const i16;
+                for i in 0..total_samples {
+                    let sample = *ptr.add(i);
+                    samples.push(sample as f32 / 32768.0);
+                }
+            }
+            32 => {
+                // Could be f32 or i32 - WASAPI mix format is usually f32
+                let ptr = data as *const f32;
+                for i in 0..total_samples {
+                    samples.push(*ptr.add(i));
+                }
+            }
+            24 => {
+                // 24-bit samples packed in 3 bytes
+                for i in 0..total_samples {
+                    let offset = (i / channels) * block_align + (i % channels) * 3;
+                    let b0 = *data.add(offset) as i32;
+                    let b1 = *data.add(offset + 1) as i32;
+                    let b2 = *data.add(offset + 2) as i32;
+                    let sample = (b2 << 16) | (b1 << 8) | b0;
+                    // Sign extend from 24 to 32 bits
+                    let sample = if sample & 0x800000 != 0 {
+                        sample | 0xFF000000u32 as i32
+                    } else {
+                        sample
+                    };
+                    samples.push(sample as f32 / 8388608.0);
+                }
+            }
+            _ => {
+                warn!("Unsupported bits per sample: {}", bits_per_sample);
+            }
+        }
+    }
+
+    samples
+}
+
+/// Convert multi-channel audio to mono by averaging channels.
+fn to_mono(samples: &[f32], channels: usize) -> Vec<f32> {
+    if channels == 1 {
+        return samples.to_vec();
+    }
+
+    samples
+        .chunks(channels)
+        .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
+        .collect()
+}
+
+/// Simple linear interpolation resampling.
+fn resample(samples: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
+    if from_rate == to_rate || samples.is_empty() {
+        return samples.to_vec();
+    }
+
+    let ratio = to_rate as f64 / from_rate as f64;
+    let new_len = (samples.len() as f64 * ratio) as usize;
+    let mut output = Vec::with_capacity(new_len);
+
+    for i in 0..new_len {
+        let src_idx = i as f64 / ratio;
+        let src_idx_floor = src_idx.floor() as usize;
+        let src_idx_ceil = (src_idx_floor + 1).min(samples.len() - 1);
+        let frac = src_idx - src_idx_floor as f64;
+
+        let sample =
+            samples[src_idx_floor] as f64 * (1.0 - frac) + samples[src_idx_ceil] as f64 * frac;
+        output.push(sample as f32);
+    }
+
+    output
+}