Similarity Search with Embeddings#
Introduction#
Embeddings are a way to compress large amounts of information into a smaller set of features that represent meaningful semantics. Instead of raw pixel values, each location is represented by a dense vector that captures the semantic content of the landscape — making it possible to search for visually similar areas using vector operations. AlphaEarth Foundations (AEF) Embeddings is an openly available global dataset of satellite embeddings derived from multiple earth observation datasets, accessible via Source Cooperative. The aef-loader package provides a convenient Python interface to query and stream these embeddings without downloading the entire dataset.
This tutorial is an open-source adaptation of our Google Earth Engine community tutorial on Satellite Embedding Similarity Search. We replicate the same workflow using open-access data and open-source packages — xarray, dask, and scikit-learn.
Overview of the Task#
We will use AlphaEarth Foundations Embeddings to perform a similarity search for grain silos in Franklin County, Kansas, USA. Starting from a few known silo locations, we compute a reference embedding vector by averaging the embeddings at those points. We then compute the cosine similarity between this reference vector and every pixel in the county, and threshold the result to identify other areas with grain silos.
Grain silos (image: Wikipedia)
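Before working with real data, the whole workflow described above can be sketched on a tiny synthetic raster. This is only an illustration under stated assumptions: the array sizes, the reference pixel positions, and the 0.5 threshold are all made up for the toy data and are not part of the tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for an embedding raster: (bands, y, x), unit-length vectors
bands, ny, nx = 8, 4, 5
emb = rng.normal(size=(bands, ny, nx))
emb /= np.linalg.norm(emb, axis=0, keepdims=True)

# Pretend three pixels are known reference locations (row, col pairs)
ref_rows = np.array([0, 1, 3])
ref_cols = np.array([0, 2, 4])
reference = emb[:, ref_rows, ref_cols].mean(axis=1)  # (bands,)

# Cosine similarity of every pixel against the reference vector
pixels = emb.reshape(bands, -1).T  # (ny*nx, bands)
sim = pixels @ reference / (
    np.linalg.norm(pixels, axis=1) * np.linalg.norm(reference)
)
similarity = sim.reshape(ny, nx)

# Threshold to a binary mask of candidate matches
threshold = 0.5  # arbitrary value chosen for the toy data
mask = similarity >= threshold
print(similarity.shape, int(mask.sum()))
```

The rest of the tutorial follows the same three steps, with Dask handling the per-pixel computation in chunks.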
Input Layers:
AlphaEarth Embeddings (AEF) for year 2024, streamed via the aef-loader package.
cb_2021_us_county_500k.zip: US county boundaries from the US Census Bureau, used to define the study area (Franklin County, Kansas).
Output:
cosine_similarity.tif: A GeoTIFF raster of per-pixel cosine similarity scores relative to the reference grain silo embedding.
predicted_matches.gpkg: A GeoPackage of vectorized polygons for pixels that exceed the similarity threshold.
Data Credit:
AlphaEarth Foundations (AEF) Satellite Embeddings : The AlphaEarth Foundations Satellite Embedding dataset is produced by Google and Google DeepMind. Accessed from Source Cooperative.
County boundary data: US Census Bureau, 2021 Cartographic Boundary Files.
Running the Notebook:
The preferred way to run this notebook is on Google Colab.
Watch Video Walkthrough: Watch a detailed explanation of the workflow. 
Setup and Data Download#
The following blocks of code will install the required packages and download the datasets to your Colab environment.
%%capture
if 'google.colab' in str(get_ipython()):
    !pip install rioxarray aef-loader dask[distributed] leafmap
import asyncio
import os
import dask.array as da
import geopandas as gpd
import leafmap.foliumap as leafmap
import numpy as np
import rioxarray as rxr
import xarray as xr
from aef_loader import AEFIndex, VirtualTiffReader, DataSource
from aef_loader.utils import dequantize_aef, reproject_datatree
from odc.geo.geobox import GeoBox
from pyproj import Transformer
from rasterio.features import shapes
from shapely.geometry import shape
from sklearn.metrics.pairwise import cosine_similarity
Set up a local Dask cluster. This distributes the computation across multiple workers on your computer.
from dask.distributed import Client
client = Client() # set up local cluster on the machine
client
If you are running this notebook in Colab, you will need to create and use a proxy URL to see the dashboard running on the local server.
if 'google.colab' in str(get_ipython()):
    from google.colab import output
    port_to_expose = 8787 # This is the default port for Dask dashboard
    print(output.eval_js(f'google.colab.kernel.proxyPort({port_to_expose})'))
data_folder = 'data'
output_folder = 'output'
if not os.path.exists(data_folder):
    os.mkdir(data_folder)
if not os.path.exists(output_folder):
    os.mkdir(output_folder)
def download(url):
    filename = os.path.join(data_folder, os.path.basename(url))
    if not os.path.exists(filename):
        from urllib.request import urlretrieve
        local, _ = urlretrieve(url, filename)
        print('Downloaded ' + local)
counties_file = 'cb_2021_us_county_500k.zip'
download('https://www2.census.gov/geo/tiger/GENZ2021/shp/' + counties_file)
Select the search region#
For this tutorial, we will map the grain silos in Franklin County, Kansas. First we read the zipped counties shapefile, then filter it to select the polygon for this county.
counties_file_path = os.path.join(data_folder, counties_file)
counties_df = gpd.read_file(counties_file_path)
selected = counties_df[counties_df['GEOID'] == '20059'] # Franklin County, Kansas
selected.iloc[:, :6]
| | STATEFP | COUNTYFP | COUNTYNS | AFFGEOID | GEOID | NAME |
|---|---|---|---|---|---|---|
| 2730 | 20 | 059 | 00484998 | 0500000US20059 | 20059 | Franklin |
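As the table shows, the GEOID used in the filter is the state FIPS code followed by the county FIPS code. A minimal sketch of building it is below. The variable names are illustrative, and the codes are kept as strings so the leading zero in the county code is preserved.

```python
# Values taken from the STATEFP and COUNTYFP columns of the table above
state_fips = "20"    # Kansas
county_fips = "059"  # Franklin County

# GEOID for a county is the two codes concatenated as text
geoid = state_fips + county_fips
print(geoid)
```

Filtering on GEOID rather than NAME avoids ambiguity, since several US states have a county named Franklin.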
m = leafmap.Map(width=600, height=500)
m.add_tile_layer(
    url='https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}',
    name='Google Satellite',
    attribution='Google',
)
m.add_gdf(selected, layer_name="Selected County")
m.zoom_to_gdf(selected)
m
Create an odc.geo.geobox.GeoBox object, which represents the bounding box with a specific CRS and pixel grid.
bbox = list(selected.total_bounds)
target_crs = 'EPSG:3857'
# Transform bbox from EPSG:4326 to chosen crs
transformer = Transformer.from_crs('EPSG:4326', target_crs, always_xy=True)
x_min, y_min = transformer.transform(bbox[0], bbox[1])
x_max, y_max = transformer.transform(bbox[2], bbox[3])
geobox = GeoBox.from_bbox(
    bbox=(x_min, y_min, x_max, y_max),
    crs=target_crs,
    resolution=10,
)
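To see what the transformer is doing, here is a small sketch of the spherical Web Mercator (EPSG:3857) forward formulas written out by hand. The helper names are hypothetical and this is not how pyproj is implemented internally. It also shows one consequence of the projection: distances in EPSG:3857 are stretched by 1/cos(latitude), so a 10 m pixel in projected units covers less than 10 m on the ground at this latitude.

```python
import math

R = 6378137.0  # sphere radius in metres used by EPSG:3857

def lonlat_to_web_mercator(lon_deg, lat_deg):
    """Spherical Web Mercator forward projection (hypothetical helper)."""
    x = R * math.radians(lon_deg)
    y = R * math.log(math.tan(math.pi / 4 + math.radians(lat_deg) / 2))
    return x, y

def ground_metres_per_projected_metre(lat_deg):
    """Mercator stretches distances by 1/cos(lat); this is the inverse factor."""
    return math.cos(math.radians(lat_deg))

# One of the reference locations used later in the tutorial
lon, lat = -95.18616479385629, 38.54715519758577
x, y = lonlat_to_web_mercator(lon, lat)
scale = ground_metres_per_projected_metre(lat)
print(round(x), round(y), round(10 * scale, 2))
```

At roughly 38.5° N the scale factor is about 0.78, so a pixel size of 10 in EPSG:3857 corresponds to a little under 8 m on the ground.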
m = leafmap.Map(width=600, height=500)
m.add_tile_layer(
    url='https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}',
    name='Google Satellite',
    attribution='Google',
)
geobox.explore(map=m)
m
Load the Satellite Embeddings#
We now use the aef-loader package to load all the matching tiles of AlphaEarth Foundations Satellite Embeddings from Source Cooperative for the chosen year. The tiles are loaded lazily as an xarray DataArray that we can fetch and process in chunks using Dask.
year = 2024
index = AEFIndex(source=DataSource.SOURCE_COOP)
await index.download()
# Query for tiles
tiles = await index.query(
    bbox=bbox,
    years=(year),
)
# Load tiles organized by UTM zone
async with VirtualTiffReader() as reader:
    tree = await reader.open_tiles_by_zone(tiles)
# Depending on the region, there may be multiple
# tiles spanning different UTM zones
# Reproject all the tiles to the target GeoBox
# with the chosen projection and pixel resolution
combined = reproject_datatree(tree, target_geobox=geobox)
embeddings = combined.embeddings
embeddings
<xarray.DataArray 'embeddings' (time: 1, band: 64, y: 4975, x: 5039)> Size: 2GB
dask.array<reproject, shape=(1, 64, 4975, 5039), dtype=int8, chunksize=(1, 5, 4975, 5039), chunktype=numpy.ndarray>
Coordinates:
* time (time) datetime64[ns] 8B 2024-01-01
* band (band) <U3 768B 'A00' 'A01' 'A02' 'A03' ... 'A61' 'A62' 'A63'
* y (y) float64 40kB 4.684e+06 4.684e+06 ... 4.635e+06 4.635e+06
* x (x) float64 40kB -1.063e+07 -1.063e+07 ... -1.058e+07
spatial_ref int32 4B 3857
Attributes:
citation: True
geog_angular_units: True
geog_citation: True
model_type: True
proj_linear_units: True
projected_type: True
raster_type: True
photometric_interpretation: 1
DESCRIPTION: A63
gdal_no_data: -128
nodata: -128
_FillValue: -128
The embeddings are stored as 8-bit integer values to save space. We use the dequantize_aef helper function provided by aef-loader to convert them back to 32-bit floating point values.
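For intuition, the sketch below shows the de-quantization mapping as I understand it from the published description of the dataset: divide by 127.5, square, and restore the sign, with -128 treated as nodata. This is an assumption about what dequantize_aef does, not a copy of its implementation, and the function name dequantize_sketch is hypothetical.

```python
import numpy as np

NODATA = -128  # nodata value reported in the DataArray attributes

def dequantize_sketch(q):
    """Map int8 codes to floats in [-1, 1]; assumed scheme, see note above."""
    q = np.asarray(q, dtype=np.int8)
    scaled = q.astype(np.float32) / 127.5
    out = np.sign(scaled) * scaled**2  # square the magnitude, keep the sign
    return np.where(q == NODATA, np.nan, out).astype(np.float32)

raw = np.array([-128, -127, 0, 64, 127], dtype=np.int8)
values = dequantize_sketch(raw)
print(values)
```

Squaring the scaled value spends more of the 8-bit range on small magnitudes, which is where most embedding values tend to lie.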
embeddings_year = embeddings.isel(time=0)
embeddings_float = dequantize_aef(embeddings_year)
embeddings_float
<xarray.DataArray 'embeddings' (band: 64, y: 4975, x: 5039)> Size: 6GB
dask.array<where, shape=(64, 4975, 5039), dtype=float32, chunksize=(5, 4975, 5039), chunktype=numpy.ndarray>
Coordinates:
* band (band) <U3 768B 'A00' 'A01' 'A02' 'A03' ... 'A61' 'A62' 'A63'
* y (y) float64 40kB 4.684e+06 4.684e+06 ... 4.635e+06 4.635e+06
* x (x) float64 40kB -1.063e+07 -1.063e+07 ... -1.058e+07
time datetime64[ns] 8B 2024-01-01
spatial_ref int32 4B 3857
Attributes: (12/14)
citation: True
geog_angular_units: True
geog_citation: True
model_type: True
proj_linear_units: True
projected_type: True
... ...
DESCRIPTION: A63
gdal_no_data: -128
nodata: nan
_FillValue: nan
units: embedding
dequantized: True
Select reference location(s)#
We pick the locations of one or more grain silos. You can use the high-resolution Google Satellite imagery to find these coordinates. For this tutorial, we have selected 3 reference locations. These will be used to extract the embedding vectors from the embeddings DataArray.
target_location1 = (-95.18616479385629, 38.54715519758577)
target_location2 = (-95.34468619878159, 38.59339901996762)
target_location3 = (-95.34280239688128, 38.56233059960432)
m = leafmap.Map(width=600, height=500)
m.add_tile_layer(
    url='https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}',
    name='Google Satellite',
    attribution='Google',
)
for i, loc in enumerate([target_location1, target_location2, target_location3], 1):
    m.add_marker(location=(loc[1], loc[0]), popup=f'Target {i}')
# Zoom to one of the selected target locations
m.set_center(target_location1[0], target_location1[1], zoom=17)
m
Extract Embeddings at Reference Locations#
We now extract the embeddings at all 3 target locations and compute their mean embedding, which will be used to calculate similarity.
# Extract embeddings from exactly 3 target locations using vectorized indexing
target_locations = [target_location1, target_location2, target_location3]
x_coords, y_coords = zip(*[
    transformer.transform(lon, lat)
    for lon, lat in target_locations])
# Using xarray DataArrays for indexing ensures we get the specific pairs
# of points (3 pixels total) instead of a grid of 3x3 pixels
target_embeddings = embeddings_float.sel(
    x=xr.DataArray(list(x_coords), dims='location'),
    y=xr.DataArray(list(y_coords), dims='location'),
    method='nearest'
)
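The difference between pointwise and grid (outer) selection can be illustrated with plain NumPy on a toy array. This is a sketch of the idea only: the coordinates are made up, and nearest_index is a hypothetical helper standing in for what method='nearest' does conceptually.

```python
import numpy as np

# Toy grid: descending y (north-up raster) and ascending x
y = np.array([30.0, 20.0, 10.0, 0.0])
x = np.array([0.0, 10.0, 20.0, 30.0, 40.0])
data = np.arange(2 * y.size * x.size).reshape(2, y.size, x.size)  # (band, y, x)

def nearest_index(coords, values):
    """Index of the nearest coordinate for each requested value."""
    diffs = np.abs(coords[None, :] - np.asarray(values)[:, None])
    return diffs.argmin(axis=1)

# Three query points given as paired (x, y) coordinates
px = [1.0, 19.0, 38.0]
py = [29.0, 11.0, 2.0]
ix = nearest_index(x, px)
iy = nearest_index(y, py)

pointwise = data[:, iy, ix]    # (band, 3): one value per (x, y) pair
outer = data[:, iy][:, :, ix]  # (band, 3, 3): every y against every x
print(pointwise.shape, outer.shape)
```

The pointwise result is the diagonal of the outer result, which is why indexing with DataArrays that share a dimension returns 3 pixels rather than a 3x3 grid.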
%%time
target_embeddings = target_embeddings.compute()
# Calculate the mean across the 3 location points
mean_embedding = target_embeddings.mean(dim='location')
Calculate Similarity#
We use scikit-learn’s cosine_similarity() function to find other pixels with similar embeddings. Let’s first test this function by calculating the cosine similarity between the mean embedding and the embedding at each reference location.
This function expects 2D arrays, i.e. tables with rows for samples and columns for features. So we use .reshape() to convert our mean array into the required shape. Similarly, target_embeddings is an array of shape (64, 3), so we transpose it (.T) to get it into the correct shape.
mean_vec = mean_embedding.values.reshape(1, -1) # (1, 64)
target_vec = target_embeddings.values.T # (3, 64)
Since these are all similar locations, we expect a very high score (close to 1) for each of them.
cosine_similarity(mean_vec, target_vec).ravel()
array([0.9211531, 0.9329779, 0.9698776], dtype=float32)
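Cosine similarity is the dot product of two vectors divided by the product of their lengths, cos θ = a·b / (‖a‖‖b‖). The following NumPy sketch computes the same quantity for row-wise 2D inputs and uses random data in place of the real embeddings. The function name cosine_similarity_np is illustrative and is not part of scikit-learn.

```python
import numpy as np

def cosine_similarity_np(a, b):
    """Pairwise cosine similarity between the rows of a and the rows of b."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    a_unit = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_unit @ b_unit.T  # shape (rows of a, rows of b)

rng = np.random.default_rng(42)
targets = rng.normal(size=(64, 3))              # same layout as target_embeddings
mean_vec = targets.mean(axis=1).reshape(1, -1)  # (1, 64)
target_vec = targets.T                          # (3, 64)

scores = cosine_similarity_np(mean_vec, target_vec)
print(scores.shape)
```

Because each row is normalized first, the score depends only on the direction of the vectors and not on their magnitude.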
We can apply this function to every pixel of the embeddings_float array. We reshape it into a 2D array in which each pixel is a row and the 64 embedding values are the columns.
# Get embeddings as dask array: (bands, y, x)
emb = embeddings_float.data
# Transpose to (y, x, bands) then reshape to (y*x, bands)
emb = da.moveaxis(emb, 0, -1) # (y, x, bands)
ny, nx, nb = emb.shape
emb_2d = emb.reshape(-1, nb) # (y*x, bands)
emb_2d
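The order of operations matters here: the band axis has to be moved to the end before reshaping, otherwise the rows of the table would mix values from different pixels. A small NumPy sketch on a toy cube (sizes chosen arbitrarily) shows the layout and how a per-pixel result maps back onto the grid.

```python
import numpy as np

nb, ny, nx = 3, 2, 4
cube = np.arange(nb * ny * nx).reshape(nb, ny, nx)  # (bands, y, x)

# Correct: bands last, then flatten the two spatial axes
flat = np.moveaxis(cube, 0, -1).reshape(-1, nb)     # (y*x, bands)

# Row r of the table is the pixel at (r // nx, r % nx)
row, col = 1, 2
r = row * nx + col
assert np.array_equal(flat[r], cube[:, row, col])

# Reshaping without moving the axis does not give per-pixel vectors
wrong = cube.reshape(-1, nb)
assert not np.array_equal(wrong[r], cube[:, row, col])

# Any per-pixel result reshapes straight back onto the (y, x) grid
per_pixel = flat.sum(axis=1)
grid = per_pixel.reshape(ny, nx)
print(grid.shape)
```

This is the same row ordering that makes the later reshape back to (ny, nx) valid.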
# Rechunk so each block has ALL bands (axis 1 = single chunk)
emb_2d_rechunked = emb_2d.rechunk({0: 'auto', 1: -1})
emb_2d_rechunked
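A rough back-of-the-envelope calculation shows what the rechunking means in terms of memory. It assumes the 'auto' setting targets Dask's default chunk size of 128 MiB. The chunk boundaries Dask actually picks may differ, so treat the numbers as an estimate only.

```python
import math

ny, nx, nb = 4975, 5039, 64       # array shape reported for the county
bytes_per_value = 4               # float32
target_chunk_bytes = 128 * 2**20  # assumed Dask default for 'auto' chunks

# With all bands in one chunk, each row (pixel) is one full embedding vector
bytes_per_pixel = nb * bytes_per_value
rows_per_chunk = target_chunk_bytes // bytes_per_pixel
total_bytes = ny * nx * bytes_per_pixel
n_chunks = math.ceil(ny * nx / rows_per_chunk)
print(bytes_per_pixel, rows_per_chunk, n_chunks, round(total_bytes / 1e9, 1))
```

Under these assumptions each chunk holds roughly half a million complete pixel vectors, so every block can be scored independently of the others.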
# Compute cosine similarity between the mean embedding and every pixel
# mean_embedding is 1-D (64,); emb_2d is (ny*nx, 64)
# Compute in chunks to stay memory-friendly
def cosine_sim_block(block, target=mean_vec):
    """Compute cosine similarity for one chunk of pixels."""
    return cosine_similarity(block, target).ravel() # (chunk_size,)
# Map over dask blocks – each block is (chunk_pixels, 64) → (chunk_pixels,)
similarity = emb_2d_rechunked.map_blocks(
    cosine_sim_block,
    dtype=np.float32, # float32 inputs give float32 scores
    drop_axis=1, # the bands axis is collapsed
)
similarity
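The block-wise pattern can be emulated without Dask to confirm that splitting along the pixel axis does not change the result. This sketch uses random data and a hand-written similarity function (cosine_sim_rows is a hypothetical name), with np.array_split standing in for Dask's chunks.

```python
import numpy as np

def cosine_sim_rows(block, target):
    """Cosine similarity of each row in block against a 1-D target vector."""
    num = block @ target
    den = np.linalg.norm(block, axis=1) * np.linalg.norm(target)
    return num / den

rng = np.random.default_rng(1)
pixels = rng.normal(size=(10, 64)).astype(np.float32)  # (pixels, bands)
target = rng.normal(size=64).astype(np.float32)

# Split along the pixel axis only; every block keeps all 64 bands
blocks = np.array_split(pixels, 3, axis=0)
blockwise = np.concatenate([cosine_sim_rows(b, target) for b in blocks])

# Same computation on the whole array at once
full = cosine_sim_rows(pixels, target)
print(blockwise.shape, np.allclose(blockwise, full, atol=1e-5))
```

Each block maps a (chunk_pixels, 64) input to a (chunk_pixels,) output, which is why the bands axis is declared as dropped and why the bands must sit in a single chunk.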
%%time
similarity_values = similarity.compute()
# Build an xarray DataArray with the same spatial coords as the input
similarity_da = xr.DataArray(
    similarity_values.reshape(ny, nx), # reshape back to (ny, nx)
    dims=['y', 'x'],
    coords={
        'y': embeddings_float.coords['y'].values,
        'x': embeddings_float.coords['x'].values,
    },
    name='cosine_similarity',
)
# Copy CRS and spatial metadata from the original embeddings
similarity_da = similarity_da.rio.write_crs(embeddings_float.rio.crs)
similarity_da = similarity_da.rio.write_transform(embeddings_float.rio.transform())
# Clip the output to the selected region
geometry = selected.to_crs(target_crs).geometry
similarity_da = similarity_da.rio.clip(geometry)
# Apply threshold and convert matching pixels to polygons
from rasterio.features import shapes
from shapely.geometry import shape
import geopandas as gpd
threshold = 0.95
# Create a binary mask: 1 where similarity >= threshold, 0 elsewhere
mask = (similarity_da.values >= threshold).astype(np.uint8)
# Get the affine transform from the similarity raster
transform = similarity_da.rio.transform()
crs = similarity_da.rio.crs
# Vectorize: convert raster mask to polygon geometries
polygons = []
for geom, val in shapes(mask, mask=(mask == 1), transform=transform):
    polygons.append(shape(geom))
# Create a GeoDataFrame
predicted_matches = gpd.GeoDataFrame(
    geometry=polygons,
    crs=crs,
)
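Vectorizing the mask groups touching pixels into a single polygon, so the number of output features is the number of connected regions rather than the number of matching pixels. The sketch below illustrates that grouping with a simple flood fill on a toy similarity grid. It assumes 4-connectivity (pixels touching by an edge), which is the default for rasterio's shapes(). The label_regions helper is hypothetical and only counts regions, it does not build geometries.

```python
from collections import deque

import numpy as np

def label_regions(mask):
    """Label 4-connected regions of True pixels; returns (labels, count)."""
    labels = np.zeros(mask.shape, dtype=np.int32)
    count = 0
    for r0, c0 in zip(*np.nonzero(mask)):
        if labels[r0, c0]:
            continue  # already part of an earlier region
        count += 1
        labels[r0, c0] = count
        queue = deque([(r0, c0)])
        while queue:
            r, c = queue.popleft()
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                rr, cc = r + dr, c + dc
                inside = 0 <= rr < mask.shape[0] and 0 <= cc < mask.shape[1]
                if inside and mask[rr, cc] and not labels[rr, cc]:
                    labels[rr, cc] = count
                    queue.append((rr, cc))
    return labels, count

similarity = np.array([
    [0.97, 0.96, 0.10, 0.20],
    [0.30, 0.40, 0.50, 0.99],
    [0.96, 0.20, 0.10, 0.98],
])
mask = similarity >= 0.95
labels, n_regions = label_regions(mask)
print(int(mask.sum()), n_regions)
```

Raising the threshold shrinks or removes regions, while lowering it merges neighbouring ones, so it is worth inspecting the result on the map when tuning it.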
m = leafmap.Map(width=600, height=500)
m.add_tile_layer(
    url='https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}',
    name='Google Satellite',
    attribution='Google',
)
for i, loc in enumerate([target_location1, target_location2, target_location3], 1):
    m.add_marker(location=(loc[1], loc[0]), popup=f'Target {i}')
m.zoom_to_gdf(selected)
m.add_gdf(predicted_matches, layer_name="Similar Areas",
          style={'color': 'red', 'fillColor': 'red', 'fillOpacity': 0.5})
m
Save the Results#
# Save to GeoTIFF
output_path = os.path.join(output_folder, 'cosine_similarity.tif')
similarity_da.rio.to_raster(output_path)
print(f'Saved similarity raster to {output_path}')
# Save to GeoPackage
vector_path = os.path.join(output_folder, 'predicted_matches.gpkg')
predicted_matches.to_file(vector_path, driver='GPKG')
print(f'Saved predicted matches to {vector_path}')