class AxisBatch(Generic[TArray]):
    """
    Batch an array along one or more axes and reassemble per-batch results.

    The batched axes are moved to the front and flattened into a single
    leading axis. Iterating the object yields consecutive slices of at most
    ``size`` rows of that flattened axis. Each processed batch is handed back
    via :meth:`send`, and :attr:`value` concatenates the sent results,
    restores the batched axes' shape and moves them back to their original
    positions.

    Usage
    -----
    >>> a = ivy.arange(10)
    >>> b = AxisBatch(a, axis=0, size=3)
    >>> for x in b:
    ...     b.send(x * 2)
    >>> print(b.value)
    """

    def __init__(
        self, a: TArray, /, *, axis: int | Sequence[int], size: int
    ) -> None:
        """
        Prepare ``a`` for batching along ``axis``.

        Parameters
        ----------
        a : TArray
            The array to batch.
        axis : int | Sequence[int]
            The axis (or axes) to batch over. Multiple axes are flattened
            together into a single batch axis.
        size : int
            The maximum number of rows (of the flattened batch axis) per batch.

        Raises
        ------
        ValueError
            If ``size`` is less than 1.
        """
        if size < 1:
            # size == 0 would otherwise fail later inside range()/division,
            # and a negative size would silently yield no batches at all.
            raise ValueError(f"size must be a positive integer, got {size}")
        if isinstance(axis, int):
            axis = (axis,)
        # Normalize to a tuple so moveaxis receives a consistent type.
        self._axis = tuple(axis)
        self._size = size
        self._axis_len = len(self._axis)
        xp = array_namespace(a)
        # Bring the batched axes to the front (in the order given).
        a = xp.moveaxis(a, self._axis, tuple(range(self._axis_len)))
        # Shape after the move; the first ``_axis_len`` entries are the batch
        # shape, needed later to undo the flattening in ``value``.
        self._shape = a.shape  # type: ignore
        # Collapse the leading batched axes into one axis of length
        # prod(batch shape), keeping the remaining axes unchanged.
        self._a = xp.reshape(a, (-1, *self._shape[self._axis_len :]))
        self._results: list[TArray] = []

    def __iter__(self) -> Iterator[TArray]:
        """
        Yield the batches.

        Yields
        ------
        TArray
            Consecutive slices of at most ``size`` rows along the flattened
            batch axis; the last batch may be shorter.
        """
        n_rows = self._a.shape[0]
        for start in range(0, n_rows, self._size):
            yield self._a[start : start + self._size]

    def __len__(self) -> int:
        """
        Return the number of batches.

        Returns
        -------
        int
            ``ceil(n_rows / size)`` where ``n_rows`` is the length of the
            flattened batch axis.
        """
        # Ceiling division via negated floor division; uses ``shape[0]`` to
        # match ``__iter__`` (array-API arrays need not implement ``__len__``).
        return -(self._a.shape[0] // -self._size)

    def send(self, result: TArray, /) -> None:
        """
        Record the processed result for the next batch.

        Parameters
        ----------
        result : TArray
            The processed batch. Its first axis must correspond to the rows
            of the batch it was computed from.
        """
        self._results.append(result)

    @property
    def value(self) -> TArray:
        """
        Return the concatenated batch results in the original axis layout.

        Returns
        -------
        TArray
            The sent results concatenated along the flattened batch axis,
            reshaped back to the batch shape and with the batched axes moved
            back to their original positions.

        Raises
        ------
        IndexError
            If no results have been sent.
        """
        if not self._results:
            raise IndexError("no batch results have been sent")
        xp = array_namespace(self._results[0])
        result = xp.concat(self._results, axis=0)
        # Undo the flattening from __init__: split the single leading axis
        # back into the batched axes' shape. Without this, batching over more
        # than one axis would move the wrong axes below.
        result = xp.reshape(
            result, (*self._shape[: self._axis_len], *result.shape[1:])
        )
        return xp.moveaxis(result, tuple(range(self._axis_len)), self._axis)