You are on page 1 of 3

import numpy as np

import matplotlib.pyplot as plt


import soundfile as sf

# Load the audio file: sf.read returns the sample array and the sample rate.
file_path = "/content/gostopaudio.wav"
audio_signal, fs = sf.read(file_path)

# Convert multichannel audio to mono by averaging the channels.
# Channels are assumed to lie along axis 1 (soundfile's frames x channels
# layout), which is why the mean is taken over axis=1.
if audio_signal.ndim == 2:
    audio_signal = np.mean(audio_signal, axis=1)

# Add white Gaussian noise to the audio signal.
# A seeded generator makes the noisy signal reproducible from run to run.
# Every downstream result (filter output, error, MSE, plots) is therefore
# reproducible as well.
noise_level = 0.1  # standard deviation of the added noise, in signal units
noise_seed = 0
rng = np.random.default_rng(noise_seed)
noise = noise_level * rng.standard_normal(len(audio_signal))
noisy_audio_signal = audio_signal + noise

# Tuning knobs for the NLMS adaptive filter.
order = 32    # number of filter taps
mu = 0.01     # adaptation step size
delta = 1e-5  # regularizer that keeps the normalization denominator non-zero

# NLMS filter function


def nlms_filter(input_signal, desired_signal, order, mu, delta):
    """Adapt an FIR filter with the normalized LMS (NLMS) algorithm.

    For each sample index ``i >= order`` the filter predicts
    ``desired_signal[i]`` from the tap window ``input_signal[i-order:i]``.
    It then moves its weights by ``mu * error * x / (x @ x + delta)``.

    Args:
        input_signal: 1-D sequence fed into the filter taps.
        desired_signal: 1-D reference sequence; must be at least as long as
            ``input_signal``.
        order: Number of filter taps.
        mu: Adaptation step size.
        delta: Small positive constant added to the tap energy, so the
            update stays finite when the tap window is all zeros.

    Returns:
        ``(output_signal, mse_values)``.

        ``output_signal`` has the same length as ``input_signal``. Its first
        ``order`` entries are 0 because no prediction is made there.

        ``mse_values`` has length ``max(len(input_signal) - order, 0)``.
        Entry ``k`` is the mean of the squared prediction errors over the
        last (up to) ``order`` steps ending at sample ``order + k``.
    """
    input_signal = np.asarray(input_signal, dtype=float)
    desired_signal = np.asarray(desired_signal, dtype=float)
    num_samples = len(input_signal)
    # With fewer samples than taps there are no prediction steps.
    # Without the clamp, np.zeros would reject the negative length.
    num_steps = max(num_samples - order, 0)

    weights = np.zeros(order)
    output_signal = np.zeros(num_samples)
    errors = np.zeros(num_steps)

    # NOTE(review): the tap window excludes the current sample
    # input_signal[i], which makes this a one-step-ahead predictor.
    # The textbook NLMS tap vector also includes input_signal[i].
    # The window is kept unchanged so the filter output matches earlier runs.
    for i in range(order, num_samples):
        x = input_signal[i - order:i]
        y_hat = np.dot(weights, x)
        error = desired_signal[i] - y_hat
        norm_factor = np.dot(x, x) + delta
        weights += (mu / norm_factor) * error * x
        output_signal[i] = y_hat
        errors[i - order] = error

    # Windowed MSE of the filter's own prediction errors: a trailing mean of
    # error**2 over (up to) `order` steps. Comparing one prediction against a
    # whole window of desired samples would not measure the filter at all.
    window = max(order, 1)
    squared_errors = errors ** 2
    if num_steps:
        # Full-mode convolution with a box of ones; its first num_steps
        # entries are the trailing window sums (shorter windows at the start).
        window_sums = np.convolve(squared_errors, np.ones(window))[:num_steps]
    else:
        window_sums = squared_errors
    counts = np.minimum(np.arange(1, num_steps + 1), window)
    mse_values = window_sums / counts

    return output_signal, mse_values

# Apply the NLMS filter to the noisy audio signal.
# The clean recording serves as the desired (reference) signal, so the
# filter adapts toward reproducing the clean samples.
desired_signal = audio_signal
filtered_audio, mse_values = nlms_filter(
    noisy_audio_signal, desired_signal, order, mu, delta
)

# Residual between the clean (original) audio and the filtered output.
# nlms_filter leaves its first `order` output samples at 0, so over that
# span the residual equals the original signal.
error = audio_signal - filtered_audio
# Plotting: one stacked panel per stage of the pipeline.
plt.figure(figsize=(12, 10))

# mse_values[k] describes sample `order + k`, because nlms_filter makes no
# prediction for the first `order` samples. Plot it against the true sample
# index so its x-axis lines up with the other panels.
mse_sample_index = np.arange(order, order + len(mse_values))

# Each entry: (title, y-axis label, positional args for plt.plot,
# keyword args for plt.plot).
panels = [
    ('Original Audio Signal', 'Amplitude', (audio_signal,),
     {'label': 'Original Audio Signal'}),
    ('Noisy Audio Signal', 'Amplitude', (noisy_audio_signal,),
     {'label': 'Noisy Audio Signal', 'alpha': 0.7}),
    ('Filtered Audio Signal (NLMS)', 'Amplitude', (filtered_audio,),
     {'label': 'Filtered Audio Signal (NLMS)', 'linestyle': 'dashed'}),
    ('Error Between Original and Filtered Signals', 'Amplitude', (error,),
     {'label': 'Error', 'color': 'red'}),
    ('Time Domain MSE', 'MSE', (mse_sample_index, mse_values),
     {'label': 'Time Domain MSE', 'color': 'green'}),
]

for position, (title, ylabel, plot_args, plot_kwargs) in enumerate(panels, start=1):
    plt.subplot(len(panels), 1, position)
    plt.plot(*plot_args, **plot_kwargs)
    plt.title(title)
    plt.xlabel('Sample')
    plt.ylabel(ylabel)
    plt.legend()

plt.tight_layout()
plt.show()

You might also like