我是靠谱客的博主 端庄御姐,最近开发中收集的这篇文章主要介绍Python sorted() 函数、Pytorch collate_fn函数、Python scatter_函数,Python 迭代器dataloader,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

编写不易如果觉得不错,麻烦关注一下~

Python sorted() 函数 

https://www.runoob.com/python/python-func-sorted.html

sort 与 sorted 区别:

sort 是应用在 list 上的方法,sorted 可以对所有可迭代的对象进行排序操作。

list 的 sort 方法返回的是对已经存在的列表进行操作,无返回值,而内建函数 sorted 方法返回的是一个新的 list,而不是在原来的基础上进行的操作。

sorted(iterable, cmp=None, key=None, reverse=False) 

>>>a = [5,7,6,3,4,1,2]
>>> b = sorted(a)       # 保留原列表
>>> a 
[5, 7, 6, 3, 4, 1, 2]
>>> b
[1, 2, 3, 4, 5, 6, 7]
 
>>> L=[('b',2),('a',1),('c',3),('d',4)]
>>> sorted(L, cmp=lambda x,y:cmp(x[1],y[1]))   # 利用cmp函数
[('a', 1), ('b', 2), ('c', 3), ('d', 4)]
>>> sorted(L, key=lambda x:x[1])               # 利用key
[('a', 1), ('b', 2), ('c', 3), ('d', 4)]
 
 
>>> students = [('john', 'A', 15), ('jane', 'B', 12), ('dave', 'B', 10)]
>>> sorted(students, key=lambda s: s[2])            # 按年龄排序
[('dave', 'B', 10), ('jane', 'B', 12), ('john', 'A', 15)]
 
>>> sorted(students, key=lambda s: s[2], reverse=True)       # 按降序
[('john', 'A', 15), ('jane', 'B', 12), ('dave', 'B', 10)]

 Pytorch collate_fn函数

https://blog.csdn.net/weixin_42028364/article/details/81675021

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch, # 批大小
    # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
    collate_fn=lambda x:(
        torch.cat(
            [x[i][j].unsqueeze(0) for i in range(len(x))], 0
            ).unsqueeze(0) for j in range(len(x[0]))
        )
    )
#这里的lamda x 的x应该就是指的dataset torch_dataset。
# 此时len(x[0])指的就是第一个列表元素的元#组#的长度,即inputing 和target 这两维,len(x)指的就 
# 是列表长度为10.


for (i,j) in loader:
    print(i)
    print(j)
    print("------------------------")

tensor([[[0, 1, 2],
         [1, 2, 3],
         [2, 3, 4]]])
tensor([[[0],
         [1],
         [2]]])
------------------------
tensor([[[3, 4, 5],
         [4, 5, 6],
         [5, 6, 7]]])
tensor([[[3],
         [4],
         [5]]])
------------------------
tensor([[[ 6,  7,  8],
         [ 7,  8,  9],
         [ 8,  9, 10]]])
tensor([[[6],
         [7],
         [8]]])
------------------------
tensor([[[ 9, 10, 11]]])
tensor([[[9]]])
------------------------

    collate_fn=lambda x:(
        torch.cat(
            [x[i][0].unsqueeze(0) for i in range(len(x))], 0
            ).unsqueeze(0)
        )

#其余不变

for (i) in loader:
    print(i)
    #print(j)
    print("------------------------")

 tensor([[[0, 1, 2],
         [1, 2, 3],
         [2, 3, 4]]])
------------------------
tensor([[[3, 4, 5],
         [4, 5, 6],
         [5, 6, 7]]])
------------------------
tensor([[[ 6,  7,  8],
         [ 7,  8,  9],
         [ 8,  9, 10]]])
------------------------
tensor([[[ 9, 10, 11]]])
------------------------

Python scatter_函数 

 https://blog.csdn.net/weixin_45547563/article/details/105311543

https://zhuanlan.zhihu.com/p/271877960

scatter_()函数有三个参数

scatter_(dim, index, src)

其中dim指的是 在哪个维度进行索引,index指的是:用来进行索引的tensor

src指的是:用来scatter的源元素。

假设源tensor为x ,且x.shape=(2, 4),index为索引,且index.shape = (2,4)

给定一个x =[[1,2,3,4] ,[5,6,7,8]]

给定一个index = [[2,1,2,2],[0,2,1,1]]

然后

print(x)

print(index)

y = torch.zeros(3,4)

y = y.scatter_(0,torch.LongTensor([[2,1,2,2],[0,2,1,1]]),x)

print(y)

结果为:

tensor([[1, 2, 3, 4],

[5, 6, 7, 8]])

tensor([[2, 1, 2, 2],

[0, 2, 1, 1]])

tensor([[5, 0.0000, 0.0000, 0.0000],

[ 0.0000, 2, 7, 8],

[1, 6, 3, 4]])

我们可以简单理解为,按照行进行填充y。 y的填充数值是x 。x的填充规则是index 。而index 每个元素与x 一一对应。填充的位置即为index 指定的行号。而填充的列值保持与x 一致。也就是说 index 只是描述出x 的行号位置。所以说,x 的1和5 元素无论怎么填充都在第0列 

Python 迭代器 


numbers = [1, 2, 3,]

# 迭代器循环 for x in iterator
for i,n in enumerate(numbers):
    print(i,n) # 0,1 / 1,3 / 2,3
 


 

最后

以上就是端庄御姐为你收集整理的Python sorted() 函数、Pytorch collate_fn函数、Python scatter_函数,Python 迭代器dataloader的全部内容,希望文章能够帮你解决Python sorted() 函数、Pytorch collate_fn函数、Python scatter_函数,Python 迭代器dataloader所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(51)

评论列表共有 0 条评论

立即
投稿
返回
顶部