First implementation of the satellite
This commit is contained in:
198
src/wyoming_client/satellite.py
Normal file
198
src/wyoming_client/satellite.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from typing import Optional
|
||||
|
||||
from wyoming.asr import Transcript
|
||||
from wyoming.audio import AudioStart, AudioChunk, AudioStop
|
||||
from wyoming.asr import Transcribe
|
||||
from wyoming.client import AsyncTcpClient
|
||||
|
||||
from ..wyoming_client.audio_buffer import AudioBuffer
|
||||
from ..wyoming_client.vad import VADDetector
|
||||
from queue import Queue, Empty
|
||||
import typer
|
||||
import asyncio
|
||||
import threading
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
import time
|
||||
|
||||
|
||||
class SatelliteController:
    """Main satellite controller with VAD-based audio streaming.

    Captures microphone audio through a ``sounddevice`` input stream, runs
    every captured chunk through the supplied VAD detector and, once an
    utterance ends, sends the recorded audio to a Wyoming ASR server and
    echoes the returned transcript to the terminal.

    Work is split as follows while :meth:`run` is active:

    * the stream callback (:meth:`_audio_callback`) only enqueues raw chunks;
    * one processing thread (:meth:`_process_audio`) does buffering and VAD;
    * one short-lived thread per finished utterance
      (:meth:`_start_transcription`) talks to the ASR server.
    """

    # Bytes per sample sent to the ASR server (16-bit PCM).
    _SAMPLE_WIDTH = 2
    # Payload size, in bytes, of each AudioChunk event sent to the server.
    _ASR_CHUNK_BYTES = 2048
    # Seconds the queue-based fallback of _start_transcription waits.
    _QUEUE_WAIT_SECONDS = 2.0

    def __init__(self, host: str, port: int, lang: str, vad_detector: "VADDetector",
                 chunk_duration: float = 0.03, timeout: float = 5.0):
        """Store connection settings and initialise audio state.

        Args:
            host: Host of the Wyoming ASR server.
            port: TCP port of the Wyoming ASR server.
            lang: Language passed in the ``Transcribe`` request.
            vad_detector: Object whose ``is_speech(chunk)`` result is
                evaluated as a boolean for every captured chunk.
            chunk_duration: Length of one captured chunk, in seconds.
            timeout: Seconds to wait for each server event while waiting
                for the transcript.
        """
        self.host = host
        self.port = port
        self.lang = lang
        self.vad_detector = vad_detector
        self.timeout = timeout

        # Audio settings
        self.sample_rate = 16000
        self.channels = 1
        self.chunk_size = int(chunk_duration * self.sample_rate)

        # Components
        self.audio_buffer = AudioBuffer(max_duration=1.0, sample_rate=self.sample_rate)

        # State
        self.is_running = False
        self.is_speaking = False
        # Raw chunks handed over from the stream callback to _process_audio.
        self.audio_queue = Queue()
        # Kept for backward compatibility: _start_transcription() reads from
        # it only when it is called without an explicit recording.
        self.transcription_queue = Queue()

    def _audio_callback(self, indata, frames, time, status):
        """Callback for the sounddevice audio stream.

        Copies the first channel of ``indata`` and enqueues it for
        :meth:`_process_audio`. A truthy ``status`` is echoed as a
        diagnostic. ``frames`` and ``time`` are accepted because the stream
        passes them, but they are not used.
        """
        if status:
            typer.echo(f"Audio callback status: {status}")

        # Copy: the enqueued chunk must not alias the caller's buffer.
        audio_chunk = indata[:, 0].copy()  # Extract mono channel
        self.audio_queue.put(audio_chunk)

    @staticmethod
    def _to_pcm16(recording) -> bytes:
        """Convert float samples to 16-bit PCM bytes (native byte order).

        Samples are scaled by 32767 and clipped to the int16 range, so
        values outside [-1.0, 1.0] saturate instead of wrapping around.
        """
        scaled = np.clip(np.asarray(recording) * 32767.0, -32768, 32767)
        return scaled.astype(np.int16).tobytes()

    async def _async_transcribe(self, pcm_bytes: bytes) -> Optional[str]:
        """Stream raw PCM data to the Wyoming ASR server and return the text.

        Args:
            pcm_bytes: 16-bit PCM audio recorded at ``self.sample_rate`` with
                ``self.channels`` channels.

        Returns:
            The transcript text, or ``None`` when the server closes the
            stream or no event arrives within ``self.timeout`` seconds.
        """
        # Use the capture settings so the announced format cannot drift from
        # what was actually recorded.
        rate = self.sample_rate
        width = self._SAMPLE_WIDTH
        channels = self.channels

        # The client instance is an async context manager.
        client = AsyncTcpClient(self.host, self.port)
        async with client:
            # 1. Send transcription request
            await client.write_event(Transcribe(language=self.lang).event())

            # 2. Start the audio stream
            await client.write_event(AudioStart(rate, width, channels).event())

            # 3. Send audio chunks
            step = self._ASR_CHUNK_BYTES
            for offset in range(0, len(pcm_bytes), step):
                chunk_bytes = pcm_bytes[offset:offset + step]
                await client.write_event(
                    AudioChunk(audio=chunk_bytes, rate=rate, width=width, channels=channels).event()
                )

            # 4. Stop the audio stream
            await client.write_event(AudioStop().event())

            # 5. Read events until a transcript arrives
            transcript_text = None
            try:
                while True:
                    # The timeout applies to each event, not to the whole exchange.
                    event = await asyncio.wait_for(client.read_event(), timeout=self.timeout)
                    if event is None:
                        break

                    if Transcript.is_type(event.type):
                        transcript_text = Transcript.from_event(event).text
                        break
            except asyncio.TimeoutError:
                typer.echo(typer.style("Connection timed out waiting for transcript.", fg=typer.colors.YELLOW))

        return transcript_text

    def _process_audio(self):
        """Process audio chunks with VAD; runs in its own thread.

        Every chunk is added to the audio buffer and classified by the VAD
        detector. A non-speech -> speech transition starts a recording; a
        speech -> non-speech transition stops it and hands the finished
        recording to a background transcription thread.
        """
        while self.is_running:
            try:
                # Short timeout so the loop notices is_running being cleared.
                audio_chunk = self.audio_queue.get(timeout=0.1)
            except Empty:
                continue

            try:
                # Add to buffer
                self.audio_buffer.add_chunk(audio_chunk)

                # Check VAD
                speech_detected = self.vad_detector.is_speech(audio_chunk)

                if speech_detected and not self.is_speaking:
                    # Start of speech
                    typer.echo(typer.style("🎤 Speech detected", fg=typer.colors.GREEN))
                    self.is_speaking = True
                    # The return value is not needed here; the audio that is
                    # transcribed is whatever stop_recording() returns.
                    self.audio_buffer.start_recording()

                elif not speech_detected and self.is_speaking:
                    # End of speech
                    typer.echo(typer.style("🔇 Speech ended", fg=typer.colors.YELLOW))
                    self.is_speaking = False
                    full_recording = self.audio_buffer.stop_recording()

                    if len(full_recording) > 0:
                        # Start the worker only now and give it the recording
                        # directly. Starting it at speech onset with a fixed
                        # queue wait dropped utterances longer than the wait
                        # and left their audio for the next worker.
                        threading.Thread(
                            target=self._start_transcription,
                            args=(full_recording,),
                            daemon=True,
                        ).start()

            except Exception as e:
                # Keep the processing loop alive; report and continue.
                typer.echo(typer.style(f"Audio processing error: {e}", fg=typer.colors.RED))

    def _start_transcription(self, recording=None):
        """Transcribe one recording in a background thread and echo the result.

        Args:
            recording: Float audio samples of the finished utterance
                (presumably in [-1.0, 1.0] given the float32 input stream;
                out-of-range values are clipped). When omitted, a recording
                is taken from ``self.transcription_queue``, waiting up to
                ``_QUEUE_WAIT_SECONDS``.
        """
        if recording is None:
            try:
                recording = self.transcription_queue.get(timeout=self._QUEUE_WAIT_SECONDS)
            except Empty:
                typer.echo(typer.style("⏰ Transcription timeout", fg=typer.colors.YELLOW))
                return

        try:
            # Convert to PCM16
            pcm_bytes = self._to_pcm16(recording)

            # Send to Wyoming ASR. asyncio.run creates and closes a fresh
            # event loop for this thread.
            transcript_text = asyncio.run(self._async_transcribe(pcm_bytes))

            if transcript_text:
                typer.echo(typer.style(f"📝 {transcript_text}", fg=typer.colors.CYAN, bold=True))
            else:
                typer.echo(typer.style("❌ No transcription received", fg=typer.colors.YELLOW))
        except Exception as e:
            # Connection or protocol failures must not kill the satellite.
            typer.echo(typer.style(f"❌ Transcription error: {e}", fg=typer.colors.RED))

    def run(self):
        """Run the satellite.

        Starts the processing thread and the microphone input stream, then
        blocks until Ctrl+C is pressed or ``is_running`` is cleared. Errors
        are echoed rather than raised.
        """
        typer.echo(typer.style("🛰️ Starting satellite mode...", fg=typer.colors.BLUE, bold=True))
        typer.echo(f"Listening on default microphone ({self.sample_rate} Hz, {self.channels} ch)")
        typer.echo(f"Wyoming server: {self.host}:{self.port} (lang: {self.lang})")
        typer.echo("Press Ctrl+C to stop")
        typer.echo("=" * 50)

        self.is_running = True

        # Start audio processing thread
        audio_thread = threading.Thread(target=self._process_audio, daemon=True)
        audio_thread.start()

        try:
            # Start audio stream
            with sd.InputStream(
                callback=self._audio_callback,
                samplerate=self.sample_rate,
                channels=self.channels,
                blocksize=self.chunk_size,
                dtype='float32',
            ):
                # Keep running until interrupted
                while self.is_running:
                    time.sleep(0.1)

        except KeyboardInterrupt:
            typer.echo(typer.style("\n🛑 Stopping satellite...", fg=typer.colors.YELLOW))
        except Exception as e:
            typer.echo(typer.style(f"❌ Satellite error: {e}", fg=typer.colors.RED))
        finally:
            self.is_running = False
            # The processing loop polls every 0.1 s, so it exits promptly;
            # the timeout keeps shutdown from hanging if it does not.
            audio_thread.join(timeout=1.0)
|
||||
Reference in New Issue
Block a user