Hi @729009982,
I’m sorry, I’ve only worked on semantic segmentation so far (I’m getting to instance segmentation as well). I understand your problem also includes bounding boxes (hence the transform on the label)? If so, you’ll probably need to write custom transformations that are deterministic, i.e. you provide the random number manually, and it is the same for all three objects (img_clean, img_noisy, label). Here I assume label is something different from an image, e.g. BBox information?
One thing I notice in the code is that the transform is being applied to some img, label, but these are not declared anywhere in the __getitem__(self, idx) function (note the double underscores on both sides). I don’t know which transform you are using, but I assume you should apply the same deterministic transform to both the noisy and the clean image. Something like:
import random

import mxnet as mx
from mxnet import gluon


class MyDoubleDataset(gluon.data.Dataset):  # Dataset lives in gluon.data, not gluon.nn
    def __init__(self, some_arguments):
        self._items = ...  # something
        # Read the dataset and create the corresponding noisy images in any way you want

    def __len__(self):
        # ....... as before, e.g.
        return len(self._items)

    # don't forget the double underscore!
    def __getitem__(self, idx):
        img_id = self._items[idx]
        clean_img_path = self.clean_image_path.format(*img_id)
        noise_img_path = self.noise_image_path.format(*img_id)
        label = self._label_cache[idx] if self._label_cache else self._load_label(idx)
        clean_img = mx.image.imread(clean_img_path, 1)
        noise_img = mx.image.imread(noise_img_path, 1)
        # Here is a problem
        if self._transform_img is not None:  # Maybe also check the label transform here? I assume it's a bbox?
            random_number = random.random()  # one draw (for example), shared by all three transforms
            trans_img_clean = self._transform_img(clean_img, random_number)
            trans_img_noise = self._transform_img(noise_img, random_number)
            trans_label = self._transform_label(label, random_number)
            return trans_img_clean, trans_img_noise, trans_label
        return clean_img, noise_img, label
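To make "deterministic transform" a bit more concrete, here is a minimal sketch of a horizontal flip that takes its randomness from the caller instead of drawing it internally. Everything in it is an assumption for illustration: the names transform_img and transform_label are hypothetical, NumPy arrays stand in for the images, the label is assumed to be an array of [xmin, ymin, xmax, ymax] boxes in pixel coordinates, and the extra img_width argument is something you would have to bind yourself (e.g. with a lambda) to match the two-argument calls in the dataset above.

```python
import random

import numpy as np


def transform_img(img, random_number, p=0.5):
    """Flip an HWC image left-right iff random_number < p (no internal randomness)."""
    if random_number < p:
        return img[:, ::-1, :]
    return img


def transform_label(label, random_number, img_width, p=0.5):
    """Mirror [xmin, ymin, xmax, ymax] boxes using the same decision rule as the image."""
    if random_number < p:
        flipped = label.copy()
        flipped[:, 0] = img_width - label[:, 2]  # new xmin comes from old xmax
        flipped[:, 2] = img_width - label[:, 0]  # new xmax comes from old xmin
        return flipped
    return label


# Toy data standing in for the clean/noisy pair and a single bounding box.
clean_img = np.arange(4 * 6 * 3, dtype=np.float32).reshape(4, 6, 3)
noise_img = clean_img + 0.1
label = np.array([[1.0, 0.0, 3.0, 2.0]])

random_number = random.random()  # drawn once per sample, then reused
trans_clean = transform_img(clean_img, random_number)
trans_noise = transform_img(noise_img, random_number)
trans_label = transform_label(label, random_number, img_width=clean_img.shape[1])
```

Because all three calls see the same random_number, the clean image, the noisy image and the boxes are either all flipped or all left alone. For mx.nd.NDArray images the flip itself would need the NDArray equivalent (I believe mx.nd.flip(img, axis=1) does this, but please check against your MXNet version).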
Also note that the dataset now returns three outputs, so you’ll need to change how you consume it in your training loop:
datagen = gluon.data.DataLoader(dataset, batch_size=..., shuffle=True)  # fill in batch_size and your other arguments
for i, data in enumerate(datagen):
    batch_clean, batch_noise, batch_label = data
    # do stuff with the network; I don't know how you feed images into it.
    # something like?
    out_clean = net(batch_clean)
    out_noise = net(batch_noise)  # ?
    break
Does this help?