diff --git a/mastodon/Mastodon.py b/mastodon/Mastodon.py index 5171d9b..e616aa3 100644 --- a/mastodon/Mastodon.py +++ b/mastodon/Mastodon.py @@ -1226,20 +1226,22 @@ class Mastodon: params_initial = locals() # Load avatar, if specified - if avatar_mime_type is None and os.path.isfile(avatar): - avatar_mime_type = mimetypes.guess_type(avatar)[0] - avatar = open(avatar, 'rb') - - if (not avatar is None and avatar_mime_type is None): - raise MastodonIllegalArgumentError('Could not determine mime type or data passed directly without mime type.') + if not avatar is None: + if avatar_mime_type is None and os.path.isfile(avatar): + avatar_mime_type = mimetypes.guess_type(avatar)[0] + avatar = open(avatar, 'rb') + + if avatar_mime_type is None: + raise MastodonIllegalArgumentError('Could not determine mime type or data passed directly without mime type.') # Load header, if specified - if header_mime_type is None and os.path.isfile(header): - header_mime_type = mimetypes.guess_type(header)[0] - header = open(header, 'rb') - - if (not header is None and header_mime_type is None): - raise MastodonIllegalArgumentError('Could not determine mime type or data passed directly without mime type.') + if not header is None: + if header_mime_type is None and os.path.isfile(header): + header_mime_type = mimetypes.guess_type(header)[0] + header = open(header, 'rb') + + if header_mime_type is None: + raise MastodonIllegalArgumentError('Could not determine mime type or data passed directly without mime type.') # Clean up params for param in ["avatar", "avatar_mime_type", "header", "header_mime_type"]: @@ -1252,7 +1254,7 @@ class Mastodon: avatar_file_name = "mastodonpyupload_" + mimetypes.guess_extension(avatar_mime_type) files["avatar"] = (avatar_file_name, avatar, avatar_mime_type) if not header is None: - header_file_name = "mastodonpyupload_" + mimetypes.guess_extension(avatar_mime_type) + header_file_name = "mastodonpyupload_" + mimetypes.guess_extension(header_mime_type) files["header"] = (header_file_name, header, header_mime_type) params = self.__generate_params(params_initial) diff --git a/tests/test_account.py b/tests/test_account.py index 240abca..2716ad1 100644 --- a/tests/test_account.py +++ b/tests/test_account.py @@ -92,3 +92,23 @@ def test_account_update_credentials(api): header = image, header_mime_type = "image/jpeg") assert account + +@pytest.mark.vcr(match_on=['path']) +def test_account_update_credentials_no_header(api): + account = api.account_update_credentials( + display_name='John Lennon', + note='I walk funny', + avatar = "tests/image.jpg") + assert account + +@pytest.mark.vcr(match_on=['path']) +def test_account_update_credentials_no_avatar(api): + with open('tests/image.jpg', 'rb') as f: + image = f.read() + + account = api.account_update_credentials( + display_name='John Lennon', + note='I walk funny', + header = image, + header_mime_type = "image/jpeg") + assert account