How ML Models Search Audio with Text Queries

In this post, I will explain how language-based audio retrieval works: the task of retrieving audio clips from a database using natural language queries. The approach combines an audio encoder and a text encoder that map clips and captions into a shared embedding space, where matching pairs end up close together.

Imagine we have a database of audio files, such as this:

[Image: example database of audio files]

We would then also have a CSV file with five captions for each audio file, like this:

Beach and Birds.wav

  • The sea gently washes up onto the shore.
  • The water is gently flowing past while birds chirp in the background
  • Water is gently flowing past while birds chirp in the background
  • Water is running as birds chirp in the background.
  • Water sound and birds are chirping in the background.

wind-quiberon01.wav

  • Ambient noises of a forest with birds chirping occasionally.
  • Soothing ambient tunes of a forest while the occasional bird chirps.
  • The wind is blowing but besides that it is pretty quiet
  • The wind is howling and get louder as time progresses.
  • The wind is howling and gets louder as time progresses.

Forder Viaduct.wav

  • Birds are singing out in nature while a large vehicle is passing nearby.
  • Many cars drive beneath the underpass on a busy highway
  • Traffic is passing, and birds are chirping as water drips nearby.
  • Traffic is passing, birds are chirping and water is dripping nearby.
  • out in nature, a large vehicle passing near by, birds singing

5am_voegel_abudhabi_park.wav

  • A cricket chirping while birds chirp in the background.
  • A cricket chirps in the foreground while birds are chirping in the background.
  • A loud, heavy wind blows steadily in the background while crickets and birds are chirping.
  • Outdoor noise with birds chirping and communicating with each other.
  • While crickets and birds are chirping, a loud and heavy wind blows steadily in the background.

And so on.
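In code, it is convenient to store this as one row per audio file. The dataset class below reads the columns file_name and caption_1 to caption_5, so a minimal sketch of the CSV, assuming exactly that layout and using two of the clips above, could look like this. It uses Python's standard csv module instead of pandas only to keep the example self-contained; load_captions is a hypothetical helper.

```python
import csv
import io

# Assumed layout: one row per audio file, five caption columns per row
CAPTIONS_CSV = """file_name,caption_1,caption_2,caption_3,caption_4,caption_5
Beach and Birds.wav,The sea gently washes up onto the shore.,The water is gently flowing past while birds chirp in the background,Water is gently flowing past while birds chirp in the background,Water is running as birds chirp in the background.,Water sound and birds are chirping in the background.
wind-quiberon01.wav,Ambient noises of a forest with birds chirping occasionally.,Soothing ambient tunes of a forest while the occasional bird chirps.,The wind is blowing but besides that it is pretty quiet,The wind is howling and get louder as time progresses.,The wind is howling and gets louder as time progresses.
"""

CAPTION_COLUMNS = [f"caption_{i}" for i in range(1, 6)]


def load_captions(text):
    """Return a list of (file_name, [five captions]) tuples from CSV text."""
    reader = csv.DictReader(io.StringIO(text))
    return [(row["file_name"], [row[c] for c in CAPTION_COLUMNS]) for row in reader]


rows = load_captions(CAPTIONS_CSV)
for file_name, captions in rows:
    print(file_name, len(captions))
```

With pandas, pd.read_csv on the same file gives the DataFrame that AudioDataset expects.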

Let's go through a super simplified example of language-based audio retrieval.

We start with the dataset:

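```python
from dataclasses import dataclass

import pandas as pd
import torch
from torch.utils.data import Dataset


# Minimal container for one sample (DataItem is not shown in the post;
# this is an assumed definition based on how it is used)
@dataclass
class DataItem:
    audio: torch.Tensor  # waveform tensor
    duration: float      # clip length
    captions: list       # the five captions for this clip
    path: str            # path to the audio file


class AudioDataset(Dataset):
    def __init__(self, captions: pd.DataFrame):
        # One row per audio file: file_name plus caption_1 ... caption_5
        self.captions = captions

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        item = self.captions.iloc[idx]

        # get_audio_tensor is a helper not shown in the post; it is assumed to
        # return (waveform tensor, sample rate, duration, path)
        audio, _, duration, path = get_audio_tensor(item["file_name"])

        caption_columns = [
            "caption_1",
            "caption_2",
            "caption_3",
            "caption_4",
            "caption_5",
        ]
        captions = item[caption_columns].tolist()

        return DataItem(audio=audio, duration=duration, captions=captions, path=path)
```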

This is fairly self-explanatory: we load the audio files and the caption CSV into a PyTorch Dataset, where each item holds one audio clip and its five captions. Keep in mind that the audio has to be returned as a tensor.
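The post does not show get_audio_tensor. As a rough idea of what such a helper has to do, here is a hypothetical stand-in, load_wav_16k, that reads a 16-bit PCM WAV file with Python's standard wave module, mixes it down to mono, and resamples it to 16 kHz, the rate the audio encoder's feature extractor is told to expect. The linear-interpolation resampler is a deliberate simplification with no anti-aliasing filter.

```python
import wave

import numpy as np

TARGET_SR = 16_000  # sampling rate passed to the audio feature extractor


def resample_linear(waveform, orig_sr, target_sr=TARGET_SR):
    """Crude linear-interpolation resampler (no anti-aliasing filter)."""
    if orig_sr == target_sr:
        return waveform.astype(np.float32)
    n_out = int(round(len(waveform) * target_sr / orig_sr))
    t_in = np.arange(len(waveform)) / orig_sr
    t_out = np.arange(n_out) / target_sr
    return np.interp(t_out, t_in, waveform).astype(np.float32)


def load_wav_16k(path):
    """Hypothetical stand-in for get_audio_tensor: returns (waveform, sample_rate,
    duration_seconds, path) for a 16-bit PCM WAV file, mixed down to mono."""
    with wave.open(path, "rb") as wf:
        sr = wf.getframerate()
        n_channels = wf.getnchannels()
        if wf.getsampwidth() != 2:
            raise ValueError("this sketch only handles 16-bit PCM")
        raw = wf.readframes(wf.getnframes())

    # 16-bit little-endian integers -> floats in [-1, 1)
    samples = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
    # Interleaved channels -> (num_frames, num_channels) -> average to mono
    mono = samples.reshape(-1, n_channels).mean(axis=1)
    duration = len(mono) / sr
    return resample_linear(mono, sr), TARGET_SR, duration, path
```

torch.from_numpy(waveform).unsqueeze(0) then turns the result into a (1, num_samples) tensor. For real use, torchaudio.load together with torchaudio.functional.resample is the more robust option.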

Next, let's look at the text encoder:

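```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TextEncoder(nn.Module):
    def __init__(self, encoding_model, tokenizer):
        super().__init__()
        # e.g. a Hugging Face tokenizer and a matching transformer encoder
        self.tokenizer = tokenizer
        self.encoding_model = encoding_model

    def forward(self, captions):
        # Tokenize a list of caption strings, padding to the longest one
        tokenized = self.tokenizer(
            captions, return_tensors="pt", padding=True, truncation=True
        )
        outputs = self.encoding_model(
            input_ids=tokenized["input_ids"].to(device),
            attention_mask=tokenized["attention_mask"].to(device),
        )

        # (num_captions, num_tokens, hidden_size), each token vector scaled to unit length
        normalized_outputs = F.normalize(outputs.last_hidden_state, dim=-1)

        return normalized_outputs
```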

What's important here is that we normalize the output of the model. The audio encoder looks very similar:

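```python
import torch.nn as nn
import torch.nn.functional as F


class AudioEncoder(nn.Module):
    def __init__(self, encoding_model, feature_extractor):
        super().__init__()
        # e.g. a Hugging Face feature extractor and a wav2vec2-style audio model
        self.feature_extractor = feature_extractor
        self.encoding_model = encoding_model

    def forward(self, waveforms):
        # The feature extractor expects audio sampled at 16 kHz
        inputs = self.feature_extractor(
            waveforms, return_tensors="pt", padding=True, sampling_rate=16_000
        ).input_values

        # squeeze(0) drops a leading dimension of size 1 (a no-op otherwise);
        # `device` is assumed to be a global torch.device
        inputs = inputs.squeeze(0).to(device)

        embeddings = self.encoding_model(inputs)

        # (batch, num_frames, hidden_size), each frame vector scaled to unit length
        normalized_outputs = F.normalize(embeddings.last_hidden_state, dim=-1)

        return normalized_outputs
```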
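Normalizing matters because, for unit-length vectors, the dot product equals the cosine similarity, which is what both the loss and the retrieval step compare. A small check in NumPy (used instead of PyTorch only to keep it self-contained; l2_normalize is a hypothetical helper mirroring F.normalize):

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    """NumPy counterpart of F.normalize(x, dim=axis): divide by max(norm, eps)."""
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


rng = np.random.default_rng(0)
a = rng.normal(size=8)
b = rng.normal(size=8)

# Cosine similarity of the raw vectors vs. dot product of the normalized ones
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
dot_of_unit = l2_normalize(a) @ l2_normalize(b)
print(np.isclose(cosine, dot_of_unit))  # True

# Averaging unit vectors (as the mean pooling in the model does) shrinks the norm
tokens = l2_normalize(rng.normal(size=(4, 8)))
print(np.linalg.norm(tokens.mean(axis=0)))
```

The second print is below 1 unless all vectors point the same way, which is why the model normalizes again after pooling and projecting.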

We can now put them in the main model:

```python
import torch.nn as nn
import torch.nn.functional as F


class RetrievalModel(nn.Module):
    def __init__(self, audio_encoder, text_encoder, loss_fn, proj_dim: int = 256):
        super().__init__()
        self.audio_encoder = audio_encoder
        self.text_encoder = text_encoder
        self.loss_fn = loss_fn

        text_hidden = text_encoder.encoding_model.config.hidden_size
        audio_hidden = audio_encoder.encoding_model.config.hidden_size

        # Linear projections into a shared proj_dim-dimensional space
        # (`device` is assumed to be a global torch.device)
        self.text_proj = nn.Linear(text_hidden, proj_dim).to(device)
        self.audio_proj = nn.Linear(audio_hidden, proj_dim).to(device)

    def forward(self, audio_waveforms, captions):
        audio_embeddings = self.audio_encoder(audio_waveforms)  # (B, frames, H_audio)
        text_embeddings = self.text_encoder(captions)           # (B*S, tokens, H_text)

        # Mean-pool over frames / tokens: one vector per clip / caption
        audio_embeddings = audio_embeddings.mean(dim=1)
        text_embeddings = text_embeddings.mean(dim=1)

        # Project into the shared space and re-normalize to unit length
        audio_embeddings = F.normalize(self.audio_proj(audio_embeddings), dim=-1)
        text_embeddings = F.normalize(self.text_proj(text_embeddings), dim=-1)

        return audio_embeddings, text_embeddings

    def loss(self, aud_emb, txt_emb):
        return self.loss_fn(aud_emb, txt_emb)
```

What matters here is that both embeddings end up in the same vector space. Each encoder's output is mean-pooled into a single vector, projected by a linear layer into a shared space with proj_dim = 256 dimensions, and normalized again.
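One caveat with text_embeddings.mean(dim=1): the tokenizer pads captions to the same length, so padding positions are averaged in too. A common alternative is a masked mean that averages only over real tokens using the attention mask. The NumPy sketch below shows that masked mean followed by a projection to a shared space; the random matrix W stands in for the nn.Linear weights (bias omitted), and all sizes are made up for illustration. Using it in the model would require the text encoder to also return tokenized["attention_mask"], which is an assumed change and not part of the code above.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)


def masked_mean(hidden, mask):
    """hidden: (B, T, H); mask: (B, T) with 1 for real tokens, 0 for padding."""
    m = mask[..., None].astype(hidden.dtype)  # (B, T, 1)
    summed = (hidden * m).sum(axis=1)         # (B, H)
    counts = np.maximum(m.sum(axis=1), 1e-9)  # (B, 1), avoids division by zero
    return summed / counts


rng = np.random.default_rng(0)
B, T, H, D = 2, 4, 8, 3  # captions, tokens, encoder width, shared projection size

hidden = l2_normalize(rng.normal(size=(B, T, H)))
mask = np.array([[1, 1, 1, 0],   # caption 0: three real tokens, one pad
                 [1, 1, 0, 0]])  # caption 1: two real tokens, two pads
W = rng.normal(size=(H, D))      # stand-in for the nn.Linear projection

pooled = masked_mean(hidden, mask)  # (B, H)
shared = l2_normalize(pooled @ W)   # (B, D), unit length
print(pooled.shape, shared.shape)  # (2, 8) (2, 3)
```

In PyTorch the same masked mean is (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True).clamp(min=1e-9).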

For training, we use a margin-based (hinge) contrastive loss:

```python
import torch
import torch.nn.functional as F


def contrastive_loss(aud_emb, txt_emb, margin=0.2):
    """
    aud_emb: (B_audio, D), L2-normalized
    txt_emb: (B_audio * S, D), L2-normalized, where each of the B_audio audios has
             S captions stored contiguously (all captions of audio 0, then audio 1, ...)
    """
    B_audio, D = aud_emb.shape
    B_text, _ = txt_emb.shape
    assert B_text % B_audio == 0, "txt_emb must be a multiple of aud_emb in length"
    S = B_text // B_audio

    # 1) repeat each audio embedding S times -> shape (B_audio*S, D)
    aud_rep = aud_emb.unsqueeze(1).expand(-1, S, -1).reshape(-1, D)

    # 2) cosine-similarity matrix, now square: (B_audio*S, B_audio*S)
    sims = aud_rep @ txt_emb.t()

    # 3) positives on the diagonal
    pos = sims.diag().unsqueeze(1)  # (B_audio*S, 1)

    # 4) hinge loss: max(0, margin + sims - pos)
    loss = F.relu(margin + sims - pos)

    # 5) zero out every pair from the same audio clip (including the diagonal),
    #    so other captions of the same clip are not treated as negatives
    clip_ids = torch.arange(B_audio, device=sims.device).repeat_interleave(S)
    same_clip = clip_ids.unsqueeze(0) == clip_ids.unsqueeze(1)
    loss = loss.masked_fill(same_clip, 0.0)

    # 6) final scalar
    return loss.mean()
```
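To see what this hinge loss does numerically, here is a NumPy version (so it runs without PyTorch) in which every pair coming from the same clip is excluded from the negatives, applied to two tiny cases with 2 clips and 2 captions each. hinge_contrastive_loss is a hypothetical name for this sketch.

```python
import numpy as np


def hinge_contrastive_loss(aud_emb, txt_emb, margin=0.2):
    """aud_emb: (B, D); txt_emb: (B * S, D), captions of each clip stored
    contiguously. Both are assumed L2-normalized."""
    B, D = aud_emb.shape
    S = txt_emb.shape[0] // B
    aud_rep = np.repeat(aud_emb, S, axis=0)          # (B*S, D), row i -> clip i // S
    sims = aud_rep @ txt_emb.T                       # (B*S, B*S) cosine similarities
    pos = np.diag(sims)[:, None]                     # (B*S, 1) matching pairs
    loss = np.maximum(0.0, margin + sims - pos)      # hinge on every pair
    clip_ids = np.repeat(np.arange(B), S)
    loss[clip_ids[:, None] == clip_ids[None, :]] = 0  # same clip is never a negative
    return loss.mean()


e1, e2 = np.eye(2)
aud = np.stack([e1, e2])          # two orthogonal clips
txt = np.stack([e1, e1, e2, e2])  # each caption sits exactly on its own clip
print(hinge_contrastive_loss(aud, txt))  # 0.0

# Collapsed embeddings: every clip and caption is the same vector
print(hinge_contrastive_loss(np.stack([e1, e1]), np.stack([e1] * 4)))
```

In the first case every negative is at least the margin below its positive, so nothing is penalized. In the second, 8 of the 16 pairs are negatives that each violate the margin by 0.2, so the loss is about 0.1.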

After training the model, we can use it to generate the embeddings for the test set:

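```python
import torch

# test_audios_df: caption DataFrame for the test split; model: the trained RetrievalModel
dataset = AudioDataset(test_audios_df)

all_aud, all_txt = [], []

model.eval()  # disable dropout and other training-only behavior
with torch.no_grad():
    for i, item in enumerate(dataset):
        print(f"Processing {i+1}/{len(dataset)}")

        wavs = item.audio.to(device)
        caps = item.captions

        # One clip -> aud_emb: (1, proj_dim); its five captions -> txt_emb: (5, proj_dim)
        aud_emb, txt_emb = model(wavs, caps)

        all_aud.append(aud_emb.cpu())
        all_txt.append(txt_emb.cpu())

aud_mat = torch.cat(all_aud)  # (num_clips, proj_dim)
txt_mat = torch.cat(all_txt)  # (5 * num_clips, proj_dim)

torch.save(aud_mat, "audio.pt")
torch.save(txt_mat, "text.pt")
```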
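Because each clip adds one row to aud_mat and its five captions add five consecutive rows to txt_mat, caption row j belongs to clip j // 5. The evaluation code relies on a create_caption_clip_mapping helper that the post does not show; a hypothetical NumPy equivalent, assuming exactly five captions per clip and the same row order as the loop above, is:

```python
import numpy as np


def caption_to_clip_index(num_clips, captions_per_clip=5):
    """Hypothetical helper: clip index for every caption row in txt_mat.
    Assumes captions were appended clip by clip, captions_per_clip at a time."""
    return np.repeat(np.arange(num_clips), captions_per_clip)


print(caption_to_clip_index(3))  # [0 0 0 0 0 1 1 1 1 1 2 2 2 2 2]
```

In PyTorch, torch.arange(len(test_audios_df)).repeat_interleave(5) builds the same mapping as a tensor.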

With the embeddings computed, we can measure how well the model performs using the Recall@K metric: for every caption, we rank all audio clips by cosine similarity and count a hit if the correct clip is among the top K. Recall@K is the fraction of captions that get a hit.

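```python
import torch
import torch.nn.functional as F

# create_caption_clip_mapping is a helper not shown in the post; it should return a
# LongTensor giving, for every caption row in txt, the index of its audio clip
ground_truth = create_caption_clip_mapping(test_audios_df)

aud = torch.load("audio.pt")
txt = torch.load("text.pt")
aud = F.normalize(aud, dim=-1)
txt = F.normalize(txt, dim=-1)

# (num_captions, num_clips): cosine similarity of every caption to every clip
sim = txt @ aud.T

print(f"Audios shape: {aud.shape}, Text shape: {txt.shape}")
print("Similarity shape: ", sim.shape)

K = 1
# Indices of the K most similar clips for each caption
vals, idxs = sim.topk(K, dim=1)
print("Top K indices shape: ", idxs.shape)

# A hit if the correct clip is among the top K
hits = (idxs == ground_truth.unsqueeze(1)).any(dim=1)

recall_at_k = hits.float().mean().item()
print(f"Recall@{K}: {recall_at_k:.3f}")
```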
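Recall@1 is strict, so results are often reported for a few values of K (for example 1, 5 and 10). A self-contained NumPy sketch of the same computation for several K at once, run on a small hand-made similarity matrix rather than real model output (recall_at_k is a hypothetical helper):

```python
import numpy as np


def recall_at_k(sim, ground_truth, ks=(1, 5, 10)):
    """sim: (num_captions, num_clips) similarities; ground_truth: (num_captions,)
    clip index per caption. Returns {K: fraction of captions whose clip is in the top K}."""
    # Sort clip indices from most to least similar for each caption
    ranking = np.argsort(-sim, axis=1)
    results = {}
    for k in ks:
        hits = (ranking[:, :k] == ground_truth[:, None]).any(axis=1)
        results[k] = float(hits.mean())
    return results


# 3 clips, 2 captions each: rows 0-1 -> clip 0, rows 2-3 -> clip 1, rows 4-5 -> clip 2
sim = np.array([
    [0.9, 0.1, 0.0],
    [0.2, 0.8, 0.1],  # ranks clip 1 above the correct clip 0
    [0.1, 0.7, 0.3],
    [0.0, 0.6, 0.5],
    [0.2, 0.3, 0.9],
    [0.4, 0.1, 0.3],  # ranks clip 0 above the correct clip 2
])
ground_truth = np.repeat(np.arange(3), 2)
print(recall_at_k(sim, ground_truth, ks=(1, 2, 3)))
```

Here Recall@1 is 4/6 and Recall@2 is already 1.0, since both misses put the correct clip in second place. For large collections, topk as in the PyTorch code avoids a full sort.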

And that's pretty much it: a very basic overview of how language-based audio retrieval works.